Commit f2992063 authored by dangqingqing

Using LoDTensor instead of Tensor in every operator.

Parent d11430e0
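
The diff below applies one pattern across the framework and the operators: outputs are now fetched as framework::LoDTensor so their LoD (sequence/level-of-detail) information is preserved, while inputs may still be read through Input<Tensor>, which a new specialization resolves to the stored LoDTensor when the variable holds one. A minimal editor's sketch of the resulting shape-inference code, using hypothetical names (MyOp, "X", "Out") that are not part of this commit:

#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

// Editor's illustration of the post-commit pattern, not code from the diff.
class MyOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
    // Reading inputs as Tensor still works; the Input<Tensor> specialization
    // added in this commit returns the underlying LoDTensor when one is stored.
    auto dims = ctx.Input<framework::Tensor>("X")->dims();
    // Before this commit: ctx.Output<framework::Tensor>("Out")->Resize(dims);
    ctx.Output<framework::LoDTensor>("Out")->Resize(dims);
  }
};

}  // namespace operators
}  // namespace paddle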
......@@ -59,7 +59,7 @@ class LoDTensor : public Tensor {
void set_lod(const LoD& lod) { lod_ = lod; }
LoD lod() { return lod_; }
LoD lod() const { return lod_; }
/*
* Get an element from LoD.
......
......@@ -186,6 +186,54 @@ void OperatorBase::GenerateTemporaryNames() {
}
}
template <>
const Tensor* InferShapeContext::Input<Tensor>(const std::string& name) const {
auto* var = InputVar(name);
if (var == nullptr) return nullptr;
if (var->IsType<LoDTensor>()) {
return &var->Get<LoDTensor>();
}
PADDLE_ENFORCE(var->IsType<Tensor>(),
"The Input(%s) must be LoDTensor or Tensor.");
return &var->Get<Tensor>();
}
template <>
const std::vector<const Tensor*> InferShapeContext::MultiInput<Tensor>(
const std::string& name) const {
auto names = op().Inputs(name);
std::vector<const Tensor*> res;
res.reserve(names.size());
std::transform(
names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) { return Input<Tensor>(sub_name); });
return res;
}
template <>
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const {
auto* var = OutputVar(name);
if (var == nullptr) return nullptr;
if (var->IsType<LoDTensor>()) {
return const_cast<LoDTensor*>(&var->Get<LoDTensor>());
}
PADDLE_ENFORCE(var->IsType<Tensor>(),
"The Input(%s) must be LoDTensor or Tensor.");
return const_cast<Tensor*>(&var->Get<Tensor>());
}
template <>
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
const std::string& name) const {
auto names = op().Outputs(name);
std::vector<Tensor*> res;
res.reserve(names.size());
std::transform(
names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) { return Output<Tensor>(sub_name); });
return res;
}
void OpProtoAndCheckerMaker::Validate() {
validated_ = true;
CheckNoDuplicatedInOutAttrs();
......
......@@ -22,6 +22,7 @@ limitations under the License. */
#include "op_info.h"
#include "paddle/framework/attribute.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
......@@ -305,11 +306,9 @@ class InferShapeContext {
auto names = op_.Inputs(name);
std::vector<const T*> res;
res.reserve(names.size());
std::transform(names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) {
auto var = scope_.FindVar(sub_name);
return var == nullptr ? nullptr : &var->Get<T>();
});
std::transform(
names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) { return Input<T>(sub_name); });
return res;
}
......@@ -318,11 +317,9 @@ class InferShapeContext {
auto names = op_.Outputs(name);
std::vector<T*> res;
res.reserve(names.size());
std::transform(names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) {
auto var = scope_.FindVar(sub_name);
return var == nullptr ? nullptr : var->GetMutable<T>();
});
std::transform(
names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) { return Output<T>(sub_name); });
return res;
}
......@@ -363,6 +360,27 @@ class ExecutionContext : public InferShapeContext {
return device_context_;
}
// redefine Output function,
// use Variable::Get instead of Variable::GetMutable
template <typename T>
T* Output(const std::string& name) const {
auto var = OutputVar(name);
return var == nullptr ? nullptr : const_cast<T*>(&var->Get<T>());
}
// redefine MultiOutput function.
// use Variable::Get instead of Variable::GetMutable
template <typename T>
std::vector<T*> MultiOutput(const std::string& name) const {
auto names = op().Outputs(name);
std::vector<T*> res;
res.reserve(names.size());
std::transform(
names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) { return Output<T>(sub_name); });
return res;
}
const platform::DeviceContext* device_context_;
};
......
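A note on the Output/MultiOutput overrides just above (my reading of the in-code comment, not stated elsewhere in the diff): Variable::GetMutable<T>() re-creates the variable's payload when the requested type does not match what is stored, so calling GetMutable<Tensor>() on a variable that already holds a LoDTensor would replace it and drop the LoD; going through Variable::Get plus const_cast reuses whatever the variable already holds. A hedged sketch of the distinction:

#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/variable.h"

// Editor's illustration only; not part of this commit.
void LodPreservationSketch() {
  paddle::framework::Variable var;
  // Suppose an earlier step created the payload as a LoDTensor.
  var.GetMutable<paddle::framework::LoDTensor>();

  // GetMutable<Tensor>() would see a type mismatch and re-create the payload
  // as a plain Tensor, losing any attached LoD, so the redefined Output path
  // deliberately avoids it:
  //   var.GetMutable<paddle::framework::Tensor>();

  // ExecutionContext::Output<T> instead reuses the stored object through
  // Variable::Get and const_cast, as in the code above.
  auto* reused = const_cast<paddle::framework::LoDTensor*>(
      &var.Get<paddle::framework::LoDTensor>());
  (void)reused;
}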
......@@ -16,8 +16,6 @@ limitations under the License. */
#include "paddle/memory/memcpy.h"
#include "paddle/platform/enforce.h"
#include <glog/logging.h>
namespace paddle {
namespace framework {
......@@ -55,7 +53,6 @@ inline T* Tensor::mutable_data(DDim dims, platform::Place place) {
template <typename T>
inline T* Tensor::mutable_data(platform::Place place) {
LOG(INFO) << "------ mutable_data ---- ";
static_assert(std::is_pod<T>::value, "T must be POD");
PADDLE_ENFORCE_GT(numel(), 0,
"Tensor's numel must be larger than zero to call "
......@@ -145,7 +142,6 @@ inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const {
}
inline Tensor& Tensor::Resize(const DDim& dims) {
LOG(INFO) << "---- resize -----";
dims_ = dims;
numel_ = product(dims_);
return *this;
......
......@@ -26,7 +26,8 @@ class AddOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("X")->dims(),
ctx.Input<Tensor>("Y")->dims(),
"Two input of Add Op's dimension must be same.");
ctx.Output<Tensor>("Out")->Resize(ctx.Input<Tensor>("X")->dims());
ctx.Output<framework::LoDTensor>("Out")->Resize(
ctx.Input<Tensor>("X")->dims());
}
};
......
......@@ -26,7 +26,7 @@ class ConcatOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto ins = ctx.MultiInput<framework::Tensor>("X");
auto *out = ctx.Output<framework::Tensor>("Out");
auto *out = ctx.Output<framework::LoDTensor>("Out");
size_t axis = static_cast<size_t>(ctx.Attr<int>("axis"));
size_t n = ins.size();
......
......@@ -32,9 +32,9 @@ class CosSimOp : public framework::OperatorWithKernel {
"Dimensions of Input(X) and Input(Y) must be the same.");
auto dims = ctx.Input<Tensor>("X")->dims();
ctx.Output<Tensor>("Out")->Resize({dims[0], 1});
ctx.Output<Tensor>("XNorm")->Resize({dims[0], 1});
ctx.Output<Tensor>("YNorm")->Resize({dims[0], 1});
ctx.Output<framework::LoDTensor>("Out")->Resize({dims[0], 1});
ctx.Output<framework::LoDTensor>("XNorm")->Resize({dims[0], 1});
ctx.Output<framework::LoDTensor>("YNorm")->Resize({dims[0], 1});
}
};
......@@ -88,8 +88,10 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
"1st dimension of Out@GRAD must equal that of Input(X)");
PADDLE_ENFORCE_EQ(out_dims[1], 1, "1st dimension of Out@GRAD must be one.");
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto *x_grad =
ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto *y_grad =
ctx.Output<framework::LoDTensor>(framework::GradVarName("Y"));
if (x_grad) x_grad->Resize(x_dims);
if (y_grad) y_grad->Resize(y_dims);
}
......
......@@ -29,7 +29,7 @@ class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(X->dims().size(), 2, "X's dimension must be 2.");
PADDLE_ENFORCE_EQ(label->dims().size(), 1, "label's dimension must be 1.");
PADDLE_ENFORCE_EQ(X->dims()[0], label->dims()[0]);
ctx.Output<Tensor>("Y")->Resize({X->dims()[0]});
ctx.Output<framework::LoDTensor>("Y")->Resize({X->dims()[0]});
}
};
......@@ -39,7 +39,7 @@ class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dX = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto X = ctx.Input<Tensor>("X");
dX->Resize(X->dims());
......
......@@ -23,7 +23,7 @@ class FillZerosLikeOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<framework::Tensor>("Dst")->Resize(
ctx.Output<framework::LoDTensor>("Dst")->Resize(
ctx.Input<framework::Tensor>("Src")->dims());
}
};
......
......@@ -28,7 +28,7 @@ class GatherOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_GE(batch_size, 0, "Batch size must be >0");
framework::DDim output_dims(ctx.Input<Tensor>("X")->dims());
output_dims[0] = batch_size;
ctx.Output<Tensor>("Out")->Resize(output_dims);
ctx.Output<framework::LoDTensor>("Out")->Resize(output_dims);
}
};
......@@ -38,7 +38,7 @@ class GatherGradOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto X_grad = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto X = ctx.Input<Tensor>("X");
X_grad->Resize(X->dims());
......
......@@ -44,7 +44,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext& context) const override {
auto* tensor = context.Output<framework::Tensor>("Out");
auto* tensor = context.Output<framework::LoDTensor>("Out");
auto dims = Attr<std::vector<int>>("dims");
std::vector<int64_t> temp;
temp.reserve(dims.size());
......
......@@ -25,7 +25,7 @@ class LookupTableOp : public framework::OperatorWithKernel {
void InferShape(const framework::InferShapeContext &context) const override {
auto table_t = context.Input<Tensor>("W");
auto ids_t = context.Input<Tensor>("Ids");
auto output_t = context.Output<Tensor>("Out");
auto output_t = context.Output<framework::LoDTensor>("Out");
output_t->Resize({ids_t->dims()[0], table_t->dims()[1]});
}
......@@ -56,7 +56,8 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &context) const override {
auto table = context.Input<Tensor>("W");
auto d_table = context.Output<Tensor>(framework::GradVarName("W"));
auto d_table =
context.Output<framework::LoDTensor>(framework::GradVarName("W"));
d_table->Resize(table->dims());
}
};
......
......@@ -25,7 +25,7 @@ class MeanOp : public framework::OperatorWithKernel {
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
"Input of MeanOp must be initialized.");
ctx.Output<Tensor>("Out")->Resize({1});
ctx.Output<framework::LoDTensor>("Out")->Resize({1});
}
};
......@@ -45,7 +45,7 @@ class MeanGradOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<Tensor>(framework::GradVarName("X"))
ctx.Output<framework::LoDTensor>(framework::GradVarName("X"))
->Resize(ctx.Input<Tensor>("X")->dims());
}
};
......
......@@ -33,7 +33,7 @@ class MinusOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(
left_tensor->numel(), right_tensor->numel(),
"Minus operator must take two tensor with same num of elements");
ctx.Output<framework::Tensor>("Out")->Resize(left_tensor->dims());
ctx.Output<framework::LoDTensor>("Out")->Resize(left_tensor->dims());
}
};
......
......@@ -18,6 +18,7 @@ namespace paddle {
namespace operators {
using framework::Tensor;
using framework::LoDTensor;
class MulOp : public framework::OperatorWithKernel {
public:
......@@ -45,7 +46,8 @@ class MulOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(
x_mat_dims[1], y_mat_dims[0],
"First matrix's width must be equal with second matrix's height.");
ctx.Output<Tensor>("Out")->Resize({x_mat_dims[0], y_mat_dims[1]});
ctx.Output<framework::LoDTensor>("Out")->Resize(
{x_mat_dims[0], y_mat_dims[1]});
}
};
......@@ -94,8 +96,10 @@ class MulOpGrad : public framework::OperatorWithKernel {
auto x_dims = ctx.Input<Tensor>("X")->dims();
auto y_dims = ctx.Input<Tensor>("Y")->dims();
auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto *x_grad =
ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto *y_grad =
ctx.Output<framework::LoDTensor>(framework::GradVarName("Y"));
auto x_mat_dims =
framework::flatten_to_2d(x_dims, Attr<int>("x_num_col_dims"));
......
......@@ -26,10 +26,11 @@ namespace operators {
using Scope = framework::Scope;
using Variable = framework::Variable;
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
void RecurrentAlgorithm::InferShape(const Scope& scope) const {
seq_len_ = scope.FindVar((arg_->inlinks[0]).external)
->GetMutable<Tensor>()
->GetMutable<LoDTensor>()
->dims()[0];
CreateScopes(scope);
auto step_scopes = GetStepScopes(scope);
......@@ -88,7 +89,7 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
// the weight are located in parent scope
for (auto& var_name : input.second) {
if (!step_scope.FindVar(var_name)) {
step_scope.NewVar(var_name)->GetMutable<Tensor>();
step_scope.NewVar(var_name)->GetMutable<LoDTensor>();
}
}
}
......@@ -106,11 +107,12 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
void RecurrentAlgorithm::InitMemories(Scope* step_scope,
bool infer_shape_mode) const {
for (auto& attr : arg_->memories) {
Tensor* pre_mem = step_scope->NewVar(attr.pre_var)->GetMutable<Tensor>();
auto* pre_mem = step_scope->NewVar(attr.pre_var)->GetMutable<LoDTensor>();
PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var) != nullptr,
"memory [%s]'s boot variable [%s] not exists", attr.var,
attr.boot_var);
Tensor* boot_mem = step_scope->FindVar(attr.boot_var)->GetMutable<Tensor>();
auto* boot_mem =
step_scope->FindVar(attr.boot_var)->GetMutable<LoDTensor>();
if (infer_shape_mode) {
pre_mem->Resize(boot_mem->dims());
PADDLE_ENFORCE_EQ(pre_mem->dims().size(), 2);
......@@ -192,9 +194,9 @@ void RecurrentGradientAlgorithm::LinkBootMemoryGradients(
"memory variable [%s] does not exists", attr.var);
PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var) != nullptr,
"boot variable [%s] does not exists", attr.boot_var);
Tensor* mem_grad = step_scope->NewVar(attr.var)->GetMutable<Tensor>();
Tensor* boot_mem_grad =
step_scope->NewVar(attr.boot_var)->GetMutable<Tensor>();
auto* mem_grad = step_scope->NewVar(attr.var)->GetMutable<LoDTensor>();
auto* boot_mem_grad =
step_scope->NewVar(attr.boot_var)->GetMutable<LoDTensor>();
if (infer_shape_mode) {
boot_mem_grad->Resize(mem_grad->dims());
} else {
......@@ -205,7 +207,7 @@ void RecurrentGradientAlgorithm::LinkBootMemoryGradients(
void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const {
seq_len_ = scope.FindVar((arg_->inlinks[0]).external)
->GetMutable<Tensor>()
->GetMutable<LoDTensor>()
->dims()[0];
auto step_scopes = GetStepScopes(scope);
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
......
......@@ -46,7 +46,7 @@ class ReshapeOp : public framework::OperatorWithKernel {
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
[](int a) { return static_cast<int64_t>(a); });
auto out_dims = framework::make_ddim(shape_int64);
ctx.Output<framework::Tensor>("Out")->Resize(out_dims);
ctx.Output<framework::LoDTensor>("Out")->Resize(out_dims);
}
};
......@@ -90,7 +90,7 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) shouldn't be null.");
auto dims = ctx.Input<framework::Tensor>("X")->dims();
auto *d_in = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto *d_in = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
d_in->Resize(dims);
}
};
......
......@@ -37,7 +37,7 @@ class RowwiseAddOp : public framework::OperatorWithKernel {
framework::slice_ddim(x_dims, num_col_dims, x_dims.size()), b_dims,
"The width of two operands must be same");
PADDLE_ENFORCE_EQ(ctx.OutputSize("Out"), 1, "The output size must be 1");
ctx.Output<Tensor>("Out")->Resize(x_dims);
ctx.Output<framework::LoDTensor>("Out")->Resize(x_dims);
}
};
......@@ -76,8 +76,8 @@ class RowwiseAddGradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(
framework::slice_ddim(x_dims, num_col_dims, x_dims.size()), b_dims,
"The width of two operands must be same");
auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *db = ctx.Output<Tensor>(framework::GradVarName("b"));
auto *dx = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto *db = ctx.Output<framework::LoDTensor>(framework::GradVarName("b"));
if (dx) dx->Resize(x_dims);
if (db) db->Resize(b_dims);
}
......
......@@ -28,7 +28,7 @@ class ScaleOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto *in = ctx.Input<framework::Tensor>("X");
auto *out = ctx.Output<framework::Tensor>("Out");
auto *out = ctx.Output<framework::LoDTensor>("Out");
out->Resize(in->dims());
}
};
......
......@@ -35,7 +35,8 @@ class ScatterOp : public framework::OperatorWithKernel {
framework::DDim data_dim(ctx.Input<Tensor>("Updates")->dims());
for (int i = 1; i < data_dim.size(); ++i)
PADDLE_ENFORCE_EQ(data_dim[i], ctx.Input<Tensor>("Updates")->dims()[i]);
ctx.Output<Tensor>("Out")->Resize(ctx.Input<Tensor>("Ref")->dims());
ctx.Output<framework::LoDTensor>("Out")->Resize(
ctx.Input<Tensor>("Ref")->dims());
}
};
......@@ -45,9 +46,11 @@ class ScatterGradOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
auto *dUpdates =
ctx.Output<framework::LoDTensor>(framework::GradVarName("Updates"));
auto *Updates = ctx.Input<Tensor>("Updates");
auto *dRef = ctx.Output<Tensor>(framework::GradVarName("Ref"));
auto *dRef =
ctx.Output<framework::LoDTensor>(framework::GradVarName("Ref"));
auto *Ref = ctx.Input<Tensor>("Ref");
dRef->Resize(Ref->dims());
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/sequence_avg_pool_op.h"
namespace paddle {
namespace operators {
class SequenceAvgPoolOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext& ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
"Input of SequenceAvgPoolOp"
"must be initialized.");
auto* x = ctx.Input<framework::LoDTensor>("X");
auto dims = x->dims();
auto lod = x->lod();
PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
PADDLE_ENFORCE_GE(
dims[0],
/*batch size = */ static_cast<int64_t>(lod[0].size() - 1),
"The first dimension of Input(X) must be large than batch size.");
dims[0] = lod[0].size() - 1;
ctx.Output<framework::LoDTensor>("Out")->Resize({dims});
}
};
class SequenceAvgPoolOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SequenceAvgPoolOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of SequenceAvgPoolOp.");
AddOutput("Out", "The output of SequenceAvgPoolOp.");
AddComment(R"DOC(
SequenceAvgPoolOp averages features of all time-steps of each instance.
More detailed comments will be added later.
)DOC");
}
};
class SequenceAvgPoolGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext& ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Gradient of Out should not be null");
auto og_dims =
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->dims();
auto x_dims = ctx.Input<framework::LoDTensor>("X")->dims();
PADDLE_ENFORCE_EQ(og_dims.size(), x_dims.size(),
"The rank of output grad must equal to Input(X).");
for (size_t i = 1; i < og_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(og_dims[i], x_dims[i], "The dimension mismatch.");
}
auto* x_grad =
ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
x_grad->Resize(x_dims);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(sequence_avg_pool, ops::SequenceAvgPoolOp,
ops::SequenceAvgPoolOpMaker, sequence_avg_pool_grad,
ops::SequenceAvgPoolGradOp);
REGISTER_OP_CPU_KERNEL(
sequence_avg_pool,
ops::SequenceAvgPoolKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
sequence_avg_pool_grad,
ops::SequenceAvgPoolGradKernel<paddle::platform::CPUPlace, float>);
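A worked example of what the average-pooling kernel in sequence_avg_pool_op.h computes (editor's illustration with made-up numbers): given lod = [[0, 2, 5]] and an input X of shape {5, w}, the output has shape {2, w}; Out row 0 is the mean of X rows 0..1 and Out row 1 is the mean of X rows 2..4. For w = 1 and X = [1, 3, 2, 4, 6], this gives Out = [2, 4].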
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/sequence_avg_pool_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
sequence_avg_pool,
ops::SequenceAvgPoolKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
sequence_avg_pool_grad,
ops::SequenceAvgPoolGradKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T>
class SequenceAvgPoolKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X");
auto* out = context.Output<LoDTensor>("Out");
auto dims = in->dims();
auto lod = in->lod();
int64_t w = in->numel() / dims[0];
out->mutable_data<T>(context.GetPlace());
auto place = context.GetEigenDevice<Place>();
for (int i = 0; i < lod[0].size() - 1; ++i) {
Tensor in_t = in->Slice<T>(static_cast<int>(lod[0][i]),
static_cast<int>(lod[0][i + 1]));
Tensor out_t = out->Slice<T>(i, i + 1);
int64_t h = static_cast<int64_t>(lod[0][i + 1] - lod[0][i]);
auto in_e = EigenMatrix<T>::From(in_t, {h, w});
auto out_e = EigenMatrix<T>::From(out_t, {h, w});
out_e.device(place) = in_e.mean(Eigen::array<int, 1>({{0}}));
}
}
};
template <typename Place, typename T>
class SequenceAvgPoolGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X");
auto* in_g = context.Output<LoDTensor>(framework::GradVarName("X"));
auto* out_g = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto dims = in->dims();
auto lod = in->lod();
int64_t w = in->numel() / dims[0];
in_g->mutable_data<T>(context.GetPlace());
auto place = context.GetEigenDevice<Place>();
for (int i = 0; i < lod[0].size() - 1; ++i) {
auto in_g_t = in_g->Slice<T>(static_cast<int>(lod[0][i]),
static_cast<int>(lod[0][i + 1]));
auto out_g_t = out_g->Slice<T>(i, i + 1);
int64_t h = static_cast<int64_t>(lod[0][i + 1] - lod[0][i]);
auto in_g_e = EigenMatrix<T>::From(in_g_t, {h, w});
auto out_g_e = EigenMatrix<T>::From(out_g_t, {1, w});
Eigen::DSizes<int, 2> bcast(h, w);
in_g_e.device(place) = (out_g_e / static_cast<T>(h)).broadcast(bcast);
}
}
};
} // namespace operators
} // namespace paddle
......@@ -23,10 +23,11 @@ class SGDOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(
ctx.Input<Tensor>("param")->dims() == ctx.Input<Tensor>("grad")->dims(),
"Two input of SGD Op's dimension must be same.");
ctx.Output<Tensor>("param_out")->Resize(ctx.Input<Tensor>("param")->dims());
PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("param")->dims(),
ctx.Input<Tensor>("grad")->dims(),
"Two input of SGD Op's dimension must be same.");
ctx.Output<framework::LoDTensor>("param_out")
->Resize(ctx.Input<Tensor>("param")->dims());
}
};
......
......@@ -23,7 +23,8 @@ class SigmoidOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<Tensor>("Y")->Resize(ctx.Input<Tensor>("X")->dims());
ctx.Output<framework::LoDTensor>("Y")->Resize(
ctx.Input<Tensor>("X")->dims());
}
};
......@@ -44,7 +45,7 @@ class SigmoidOpGrad : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<Tensor>(framework::GradVarName("X"))
ctx.Output<framework::LoDTensor>(framework::GradVarName("X"))
->Resize(ctx.Input<Tensor>("Y")->dims());
}
};
......
......@@ -25,7 +25,8 @@ class SoftmaxOp : public framework::OperatorWithKernel {
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.Input<Tensor>("X")->dims().size() == 2UL,
"The input of softmax op must be a matrix.");
ctx.Output<Tensor>("Y")->Resize(ctx.Input<Tensor>("X")->dims());
ctx.Output<framework::LoDTensor>("Y")->Resize(
ctx.Input<Tensor>("X")->dims());
}
};
......@@ -71,7 +72,7 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
ctx.Input<Tensor>(framework::GradVarName("Y"))->dims(),
"Input(Y) and its gradients should have a same shape.");
ctx.Output<Tensor>(framework::GradVarName("X"))
ctx.Output<framework::LoDTensor>(framework::GradVarName("X"))
->Resize(ctx.Input<Tensor>("X")->dims());
}
};
......
......@@ -48,9 +48,9 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel {
"First dimension of target must be equal to input "
"or to 1.");
ctx.Output<Tensor>("sub_result")
ctx.Output<framework::LoDTensor>("sub_result")
->Resize({x_dims[0], x->numel() / x_dims[0]});
ctx.Output<Tensor>("Out")->Resize({x_dims[0], 1});
ctx.Output<framework::LoDTensor>("Out")->Resize({x_dims[0], 1});
}
};
......@@ -94,8 +94,10 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(out_dims[1], 1,
"Second dimension of output gradient "
"must be 1.");
auto* x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto* x_grad =
ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto* y_grad =
ctx.Output<framework::LoDTensor>(framework::GradVarName("Y"));
if (x_grad) x_grad->Resize(x_dims);
if (y_grad) y_grad->Resize(y_dims);
}
......
......@@ -23,7 +23,7 @@ class SumOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto ins = ctx.MultiInput<framework::Tensor>("X");
auto *out = ctx.Output<framework::Tensor>("Out");
auto *out = ctx.Output<framework::LoDTensor>("Out");
int N = ins.size();
auto in_dim = ins[0]->dims();
......@@ -55,7 +55,8 @@ class SumGradOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto outputs = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
auto outputs =
ctx.MultiOutput<framework::LoDTensor>(framework::GradVarName("X"));
auto dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
for (auto output : outputs) {
output->Resize(dims);
......
......@@ -35,8 +35,8 @@ class TopkOp : public framework::OperatorWithKernel {
framework::DDim dims = input->dims();
dims[dims.size() - 1] = k;
ctx.Output<Tensor>("Out")->Resize(dims);
ctx.Output<Tensor>("Indices")->Resize(dims);
ctx.Output<framework::LoDTensor>("Out")->Resize(dims);
ctx.Output<framework::LoDTensor>("Indices")->Resize(dims);
}
};
......
......@@ -50,7 +50,7 @@ class UniformRandomOp : public framework::OperatorWithKernel {
void InferShape(const framework::InferShapeContext& ctx) const override {
PADDLE_ENFORCE(Attr<float>("min") < Attr<float>("max"),
"uniform_random's min must less then max");
auto* tensor = ctx.Output<framework::Tensor>("Out");
auto* tensor = ctx.Output<framework::LoDTensor>("Out");
auto dims = Attr<std::vector<int>>("dims");
std::vector<int64_t> temp;
temp.reserve(dims.size());
......
......@@ -122,6 +122,8 @@ PYBIND11_PLUGIN(core) {
});
py::class_<LoDTensor, Tensor>(m, "LoDTensor")
.def_buffer(
[](Tensor &self) -> py::buffer_info { return CastToPyBuffer(self); })
.def(
"__init__",
[](LoDTensor &instance, const std::vector<std::vector<size_t>> &lod) {
......@@ -172,10 +174,11 @@ All parameter, weight, gradient are variables in Paddle.
.def("set_int",
[](Variable &var, int val) -> void { *var.GetMutable<int>() = val; })
.def("get_int", [](const Variable &var) -> int { return var.Get<int>(); })
// .def("get_tensor",
// [](Variable &self) -> Tensor * { return
// self.GetMutable<Tensor>(); },
// py::return_value_policy::reference)
.def("get_tensor",
[](Variable &self) -> Tensor * { return self.GetMutable<Tensor>(); },
py::return_value_policy::reference)
.def("get_lod_tensor",
[](Variable &self) -> LoDTensor * {
return self.GetMutable<LoDTensor>();
},
......
......@@ -42,7 +42,6 @@ template <size_t I, typename... ARGS>
struct CastToPyBufferImpl<true, I, ARGS...> {
using CUR_TYPE = typename std::tuple_element<I, std::tuple<ARGS...>>::type;
py::buffer_info operator()(framework::Tensor &tensor) {
LOG(INFO) << "---- CastToPyBufferImpl -----";
if (std::type_index(typeid(CUR_TYPE)) == tensor.holder_->type()) {
auto dim_vec = framework::vectorize(tensor.dims());
std::vector<size_t> dims_outside;
......
......@@ -4,7 +4,7 @@ import numpy
class TestTensor(unittest.TestCase):
def not_test_int_tensor(self):
def test_int_tensor(self):
scope = core.Scope()
var = scope.new_var("test_tensor")
place = core.CPUPlace()
......@@ -23,7 +23,7 @@ class TestTensor(unittest.TestCase):
self.assertEqual(1, tensor_array_2[3, 9])
self.assertEqual(2, tensor_array_2[19, 11])
def not_test_float_tensor(self):
def test_float_tensor(self):
scope = core.Scope()
var = scope.new_var("test_tensor")
place = core.CPUPlace()
......@@ -44,82 +44,66 @@ class TestTensor(unittest.TestCase):
self.assertAlmostEqual(2.0, tensor_array_2[19, 11])
def test_int_lod_tensor(self):
places = [core.CPUPlace(), core.GPUPlace(0)]
for place in places:
scope = core.Scope()
#var = scope.new_var("test_tensor")
var_lod = scope.new_var("test_lod_tensor")
# tensor = var.get_tensor()
lod_tensor = var_lod.get_lod_tensor()
lod_tensor.set_dims([4, 4, 6])
lod_tensor.alloc_int(place)
print lod_tensor
array = numpy.array(lod_tensor)
print "---- array ----", array
array[0, 0, 0] = 3
array[3, 3, 5] = 10
lod_tensor.set(array, place)
# lod_tensor.set_tensor(tensor)
lod_tensor.set_lod([[0, 2, 4]])
# lod_v = numpy.array(lod_tensor.tensor())
lod_v = numpy.array(lod_tensor)
self.assertTrue(numpy.alltrue(array == lod_v))
lod = lod_tensor.lod()
self.assertEqual(0, lod[0][0])
self.assertEqual(2, lod[0][1])
self.assertEqual(4, lod[0][2])
def not_test_float_lod_tensor(self):
places = [core.CPUPlace(), core.GPUPlace(0)]
for place in places:
scope = core.Scope()
var = scope.new_var("test_tensor")
var_lod = scope.new_var("test_lod_tensor")
tensor = var.get_tensor()
lod_tensor = var_lod.get_lod_tensor()
tensor.set_dims([5, 2, 3, 4])
tensor.alloc_float(place)
tensor_array = numpy.array(tensor)
self.assertEqual((5, 2, 3, 4), tensor_array.shape)
tensor_array[0, 0, 0, 0] = 1.0
tensor_array[0, 0, 0, 1] = 2.0
tensor.set(tensor_array, place)
lod_tensor.set_tensor(tensor)
lod_v = numpy.array(lod_tensor.tensor())
self.assertAlmostEqual(1.0, lod_v[0, 0, 0, 0])
self.assertAlmostEqual(2.0, lod_v[0, 0, 0, 1])
self.assertEqual(len(lod_tensor.lod()), 0)
lod_py = [[0, 2, 5], [0, 2, 4, 5]]
lod_tensor.set_lod(lod_py)
lod = lod_tensor.lod()
self.assertListEqual(lod_py, lod)
def not_test_lod_tensor_init(self):
place = core.CPUPlace()
scope = core.Scope()
var = scope.new_var("test_tensor")
var_lod = scope.new_var("test_lod_tensor")
lod_tensor = var_lod.get_tensor()
lod_tensor.set_dims([4, 4, 6])
lod_tensor.alloc_int(place)
array = numpy.array(lod_tensor)
array[0, 0, 0] = 3
array[3, 3, 5] = 10
lod_tensor.set(array, place)
lod_tensor.set_lod([[0, 2, 4]])
lod_v = numpy.array(lod_tensor)
self.assertTrue(numpy.alltrue(array == lod_v))
lod = lod_tensor.lod()
self.assertEqual(0, lod[0][0])
self.assertEqual(2, lod[0][1])
self.assertEqual(4, lod[0][2])
def test_float_lod_tensor(self):
place = core.CPUPlace()
tensor = var.get_tensor()
tensor.set_dims([5, 2, 3, 4])
tensor.alloc_float(place)
tensor_array = numpy.array(tensor)
scope = core.Scope()
var_lod = scope.new_var("test_lod_tensor")
lod_tensor = var_lod.get_tensor()
lod_tensor.set_dims([5, 2, 3, 4])
lod_tensor.alloc_float(place)
tensor_array = numpy.array(lod_tensor)
self.assertEqual((5, 2, 3, 4), tensor_array.shape)
tensor_array[0, 0, 0, 0] = 1.0
tensor_array[0, 0, 0, 1] = 2.0
tensor.set(tensor_array, place)
lod_tensor.set(tensor_array, place)
lod_v = numpy.array(lod_tensor)
self.assertAlmostEqual(1.0, lod_v[0, 0, 0, 0])
self.assertAlmostEqual(2.0, lod_v[0, 0, 0, 1])
self.assertEqual(len(lod_tensor.lod()), 0)
lod_py = [[0, 2, 5], [0, 2, 4, 5]]
lod_tensor.set_lod(lod_py)
lod = lod_tensor.lod()
self.assertListEqual(lod_py, lod)
def test_lod_tensor_init(self):
scope = core.Scope()
place = core.CPUPlace()
lod_py = [[0, 2, 5], [0, 2, 4, 5]]
lod_tensor = core.LoDTensor(lod_py)
lod_tensor.set_dims([5, 2, 3, 4])
lod_tensor.alloc_float(place)
tensor_array = numpy.array(lod_tensor)
tensor_array[0, 0, 0, 0] = 1.0
tensor_array[0, 0, 0, 1] = 2.0
lod_tensor.set(tensor_array, place)
lod_tensor = core.LoDTensor(lod_py, tensor)
lod_v = numpy.array(lod_tensor.tensor())
lod_v = numpy.array(lod_tensor)
self.assertAlmostEqual(1.0, lod_v[0, 0, 0, 0])
self.assertAlmostEqual(2.0, lod_v[0, 0, 0, 1])
self.assertListEqual(lod_py, lod_tensor.lod())
......