提交 cb284283 编写于 作者: D dangqingqing

Replace LoDTensor in elementwise_mul_op, pad_op and recurrent_op_utils.

上级 30a58b51
......@@ -189,13 +189,7 @@ void OperatorBase::GenerateTemporaryNames() {
template <>
const Tensor* InferShapeContext::Input<Tensor>(const std::string& name) const {
auto* var = InputVar(name);
if (var == nullptr) return nullptr;
if (var->IsType<LoDTensor>()) {
return &var->Get<LoDTensor>();
}
PADDLE_ENFORCE(var->IsType<Tensor>(),
"The Input(%s) must be LoDTensor or Tensor.");
return &var->Get<Tensor>();
return var == nullptr ? nullptr : GetTensorFromVar(var);
}
template <>
......@@ -204,9 +198,11 @@ const std::vector<const Tensor*> InferShapeContext::MultiInput<Tensor>(
auto names = op().Inputs(name);
std::vector<const Tensor*> res;
res.reserve(names.size());
std::transform(
names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) { return Input<Tensor>(sub_name); });
std::transform(names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) {
auto var = scope_.FindVar(sub_name);
return var == nullptr ? nullptr : GetTensorFromVar(var);
});
return res;
}
......@@ -214,12 +210,7 @@ template <>
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const {
auto* var = OutputVar(name);
if (var == nullptr) return nullptr;
if (var->IsType<LoDTensor>()) {
return const_cast<LoDTensor*>(&var->Get<LoDTensor>());
}
PADDLE_ENFORCE(var->IsType<Tensor>(),
"The Input(%s) must be LoDTensor or Tensor.");
return const_cast<Tensor*>(&var->Get<Tensor>());
return GetTensorFromVar(var);
}
template <>
......@@ -228,9 +219,11 @@ std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
auto names = op().Outputs(name);
std::vector<Tensor*> res;
res.reserve(names.size());
std::transform(
names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) { return Output<Tensor>(sub_name); });
std::transform(names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) {
auto var = scope().FindVar(sub_name);
return var == nullptr ? nullptr : GetTensorFromVar(var);
});
return res;
}
......
......@@ -306,9 +306,11 @@ class InferShapeContext {
auto names = op_.Inputs(name);
std::vector<const T*> res;
res.reserve(names.size());
std::transform(
names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) { return Input<T>(sub_name); });
std::transform(names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) {
auto var = scope_.FindVar(sub_name);
return var == nullptr ? nullptr : &var->Get<T>();
});
return res;
}
......@@ -317,12 +319,23 @@ class InferShapeContext {
auto names = op_.Outputs(name);
std::vector<T*> res;
res.reserve(names.size());
std::transform(
names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) { return Output<T>(sub_name); });
std::transform(names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) {
auto var = scope_.FindVar(sub_name);
return var == nullptr ? nullptr : var->GetMutable<T>();
});
return res;
}
Tensor* GetTensorFromVar(const Variable* var) const {
if (var->IsType<LoDTensor>()) {
return const_cast<LoDTensor*>(&var->Get<LoDTensor>());
}
PADDLE_ENFORCE(var->IsType<Tensor>(),
"The Input(%s) must be LoDTensor or Tensor.");
return const_cast<Tensor*>(&var->Get<Tensor>());
}
private:
const OperatorBase& op_;
const Scope& scope_;
......
......@@ -31,7 +31,7 @@ class ElementWiseMulOp : public framework::OperatorWithKernel {
auto y_dim = ctx.Input<Tensor>("Y")->dims();
PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
"Rank of first input must >= rank of second input.")
ctx.Output<Tensor>("Out")->Resize(x_dim);
ctx.Output<framework::Tensor>("Out")->Resize(x_dim);
}
};
......@@ -80,8 +80,8 @@ class ElementWiseMulOpGrad : public framework::OperatorWithKernel {
auto x_dims = ctx.Input<Tensor>("X")->dims();
auto y_dims = ctx.Input<Tensor>("Y")->dims();
auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto *x_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto *y_grad = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
"Rank of first input must >= rank of second input.")
......
......@@ -34,7 +34,8 @@ class PadOp : public framework::OperatorWithKernel {
for (int i = 0; i < x_dim.size(); ++i) {
out_dims[i] = x_dim[i] + paddings[i * 2] + paddings[i * 2 + 1];
}
ctx.Output<Tensor>("Out")->Resize(framework::make_ddim(out_dims));
ctx.Output<framework::LoDTensor>("Out")->Resize(
framework::make_ddim(out_dims));
}
};
......@@ -95,9 +96,9 @@ class PadOpGrad : public framework::OperatorWithKernel {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto x_dims = ctx.Input<Tensor>("X")->dims();
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
if (x_grad != nullptr) {
x_grad->Resize(x_dims);
auto *x_g = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
if (x_g != nullptr) {
x_g->Resize(x_dims);
}
}
};
......
......@@ -21,6 +21,7 @@ namespace rnn {
namespace f = paddle::framework;
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
void SegmentInputs(const std::vector<Scope*>& step_scopes,
const std::vector<Link>& inlinks, const size_t seq_len,
......@@ -31,7 +32,7 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
PADDLE_ENFORCE(input_var != nullptr, "input link [%s] is not in scope.",
inlinks[i].external);
Tensor* input = input_var->GetMutable<Tensor>();
LoDTensor* input = input_var->GetMutable<LoDTensor>();
f::DDim dims = input->dims();
PADDLE_ENFORCE(static_cast<size_t>(dims[0]) == seq_len,
"all the inlinks must have same length");
......@@ -40,6 +41,8 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
Tensor* step_input =
step_scopes[j]->NewVar(inlinks[i].internal)->GetMutable<Tensor>();
if (!infer_shape_mode) {
// The input of operators of each step is Tensor here.
// Maybe need to modify Slice function.
*step_input = input->Slice<float>(j, j + 1);
}
step_input->Resize(step_dims);
......@@ -54,21 +57,23 @@ void ConcatOutputs(const std::vector<Scope*>& step_scopes,
auto output_var = step_scopes[0]->FindVar(outlinks[i].external);
PADDLE_ENFORCE(output_var != nullptr, "output link [%s] is not in scope.",
outlinks[i].external);
Tensor* output = output_var->GetMutable<Tensor>();
LoDTensor* output = output_var->GetMutable<LoDTensor>();
if (infer_shape_mode) {
auto step_scope_var = step_scopes[0]->FindVar(outlinks[i].internal);
PADDLE_ENFORCE(step_scope_var != nullptr, "%s not in scope",
outlinks[i].internal);
f::DDim step_dims = step_scope_var->template GetMutable<Tensor>()->dims();
f::DDim step_dims =
step_scope_var->template GetMutable<LoDTensor>()->dims();
std::vector<int64_t> dims_vec = vectorize(step_dims);
dims_vec.insert(dims_vec.begin(), seq_len);
output->Resize(f::make_ddim(dims_vec));
} else {
output->mutable_data<float>(platform::CPUPlace());
for (size_t j = 0; j < seq_len; j++) {
Tensor* step_output =
step_scopes[j]->FindVar(outlinks[i].internal)->GetMutable<Tensor>();
LoDTensor* step_output = step_scopes[j]
->FindVar(outlinks[i].internal)
->GetMutable<LoDTensor>();
// TODO(luotao02) data type and platform::DeviceContext() should set
// correctly
(output->Slice<float>(j, j + 1))
......@@ -94,8 +99,8 @@ void LinkMemories(const std::vector<Scope*>& scopes,
auto scope = scopes[step_id];
auto linked_scope = scopes[step_id + offset];
for (auto& attr : memories) {
auto mem = scope->FindVar(attr.pre_var)->GetMutable<Tensor>();
auto linked_mem = linked_scope->FindVar(attr.var)->GetMutable<Tensor>();
auto mem = scope->FindVar(attr.pre_var)->GetMutable<LoDTensor>();
auto linked_mem = linked_scope->FindVar(attr.var)->GetMutable<LoDTensor>();
if (infer_shape_mode) {
mem->Resize(linked_mem->dims());
} else {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册