提交 8a4b7663 编写于 作者: W wangliu

fix #305

上级 b1a7eddf
......@@ -58,6 +58,7 @@ class OperatorBase : PaddleMobileObject {
std::shared_ptr<Scope> scope);
virtual ~OperatorBase() {}
virtual void Run() const = 0;
virtual void InferShape() const = 0;
const VariableNameMap &Inputs() const { return inputs_; }
const VariableNameMap &Outputs() const { return outputs_; }
......@@ -87,8 +88,8 @@ class OperatorWithKernel : public OperatorBase<Dtype> {
const VariableNameMap &outputs, const AttributeMap &attrs,
std::shared_ptr<Scope> scope)
: OperatorBase<Dtype>(type, inputs, outputs, attrs, scope) {}
virtual void InferShape() const = 0;
virtual void Run() const = 0;
virtual void InferShape() const = 0;
};
template <typename Dtype, typename P>
......
......@@ -34,19 +34,10 @@ Variable *Scope::Var(const std::string &name) {
}
pvar = new Variable;
vars_[name] = pvar;
pvar->name_ = &(vars_.find(name)->first);
pvar->name_ = vars_.find(name)->first;
return pvar;
}
// Variable* Scope::Var(std::string* name) {
// auto var_name = string::Sprintf("%p.%d", this,
// vars_.size());
// if (name != nullptr) {
// *name = var_name;
// }
// return Var(var_name);
// }
Variable *Scope::FindVar(const std::string &name) const {
auto *pvar = FindVarLocally(name);
if (pvar != nullptr) {
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <common/enforce.h>
#include <cstdint>
#include <cstring>
#include <memory>
......@@ -217,18 +218,14 @@ class Tensor {
}
inline void check_memory_size() const {
// PADDLE_ENFORCE_NOT_NULL(
// holder_, "Tensor holds no memory. Call
// Tensor::mutable_data
// first.");
// PADDLE_ENFORCE_LE(
// numel() * SizeOfType(type()), memory_size(),
// "Tensor's dims_ is out of bound. Call
// Tensor::mutable_data "
// "first to re-allocate memory.\n"
// "or maybe the required data-type mismatches the data
// already
// stored.");
PADDLE_MOBILE_ENFORCE(
holder_, "Tensor holds no memory. Call Tensor::mutable_data first.");
PADDLE_MOBILE_ENFORCE(
numel() * SizeOfType(type()) <= memory_size(),
"Tensor's dims_ is out of bound. CallTensor::mutable_data "
"first to re-allocate memory.\n"
"or maybe the required data-type mismatches the data\
already stored.");
}
inline DataLayout layout() const { return layout_; }
......
......@@ -19,10 +19,13 @@ limitations under the License. */
#include <string>
#include <typeindex>
#include <typeinfo>
#include "../common/variant.h"
#include "paddle_mobile_object.h"
namespace paddle_mobile {
namespace framework {
using std::string;
class Variable : public PaddleMobileObject {
public:
template <typename T>
......@@ -30,17 +33,23 @@ class Variable : public PaddleMobileObject {
return static_cast<const T *>(holder_->Ptr());
}
template <typename T>
const T GetValue() const {
return variant.Get<T>();
}
template <typename T>
void SetValue(T value) {
variant.Set<T>(value);
}
bool IsInitialized() const { return holder_ != nullptr; }
const std::string *Name() { return name_; }
const std::string Name() { return name_; }
template <typename T>
T *GetMutable() {
if (!IsType<T>()) {
if (*Name() == "pixel") {
// std::cout << " reset " << *Name() <<
// std::endl;
}
holder_.reset(new PlaceholderImp<T>(new T()));
}
return static_cast<T *>(holder_->Ptr());
......@@ -48,15 +57,6 @@ class Variable : public PaddleMobileObject {
template <typename T>
bool IsType() const {
if (holder_) {
// printf("not null \n");
printf(" holder type : %s, this type %s \n", holder_->Type().name(),
typeid(T).name());
}
// std::cout << " " << holder_->Type() << " " <<
// typeid(T) <<
// std::endl;
return holder_ != nullptr && holder_->Type() == typeid(T);
}
......@@ -64,7 +64,7 @@ class Variable : public PaddleMobileObject {
std::type_index Type() const { return holder_->Type(); }
void SetName(const std::string *name) { name_ = name; }
void SetName(const string name) { name_ = name; }
private:
struct Placeholder {
......@@ -87,10 +87,10 @@ class Variable : public PaddleMobileObject {
std::unique_ptr<T> ptr_;
const std::type_info &type_;
};
Variant<int, bool, string, float, double> variant;
std::unique_ptr<Placeholder> holder_;
friend class Scope;
const std::string *name_;
string name_;
};
} // namespace framework
} // namespace paddle_mobile
......@@ -27,6 +27,7 @@ limitations under the License. */
#include "framework/tensor.h"
namespace paddle_mobile {
using framework::Variable;
void ReadBinaryFile(const std::string &filename, std::string *contents) {
std::ifstream fin(filename, std::ios::in | std::ios::binary);
......@@ -204,10 +205,12 @@ const framework::Program<Dtype, P> Loader<Dtype, P>::Load(
var_desc->Type() != framework::VARTYPE_TYPE_FEED_MINIBATCH &&
var_desc->Type() != framework::VARTYPE_TYPE_FETCH_LIST) {
// DLOG << "to load var ";
LoadVar(var, *var_desc, dirname + "/" + var_desc->Name());
auto dim = var_desc->Tensor_desc().Dims();
auto tensor = var->GetMutable<framework::LoDTensor>();
tensor->Resize(framework::make_ddim(dim));
} else {
auto dim = var_desc->Tensor_desc().Dims();
PADDLE_MOBILE_ENFORCE(dim.size() > 0, "dim size is 0");
PADDLE_MOBILE_ENFORCE(dim.size() > 1, "dim size is 0");
dim[0] = 1;
auto tensor = var->GetMutable<framework::LoDTensor>();
tensor->Resize(framework::make_ddim(dim));
......@@ -243,11 +246,39 @@ Executor<Dtype, P>::Executor(const framework::Program<Dtype> p) : program_(p) {
std::vector<std::shared_ptr<framework::OpDesc>> ops = block_desc->Ops();
for (int j = 0; j < ops.size(); ++j) {
std::shared_ptr<framework::OpDesc> op = ops[j];
// auto op_base =
// framework::OpRegistry<Dtype>::CreateOp(op->Type(),
// op->GetInputs(), op->GetOutputs(),
// op->GetAttrMap(), program_.scope);
// op_base->InferShape();
auto op_base = framework::OpRegistry<Dtype>::CreateOp(
op->Type(), op->GetInputs(), op->GetOutputs(), op->GetAttrMap(),
program_.scope);
op_base->InferShape();
ops_of_block_[*block_desc.get()].push_back(op_base);
}
}
InitMemory();
}
template <typename Dtype, Precision P>
Executor<Dtype, P>::Executor(const framework::Program<Dtype> p, int batch_size)
: program_(p), batch_size_(batch_size) {
if (use_optimize_) {
to_predict_program_ = program_.optimizeProgram;
} else {
to_predict_program_ = program_.originProgram;
}
Variable *variable_ptr = program_.scope->Var("batch_size");
variable_ptr[0].SetValue<int>(batch_size);
const std::vector<std::shared_ptr<framework::BlockDesc>> blocks =
to_predict_program_->Blocks();
for (int i = 0; i < blocks.size(); ++i) {
std::shared_ptr<framework::BlockDesc> block_desc = blocks[i];
std::vector<std::shared_ptr<framework::OpDesc>> ops = block_desc->Ops();
for (int j = 0; j < ops.size(); ++j) {
std::shared_ptr<framework::OpDesc> op = ops[j];
auto op_base = framework::OpRegistry<Dtype>::CreateOp(
op->Type(), op->GetInputs(), op->GetOutputs(), op->GetAttrMap(),
program_.scope);
op_base->InferShape();
ops_of_block_[*block_desc.get()].push_back(op_base);
}
}
InitMemory();
......@@ -342,6 +373,9 @@ void Executor<Dtype, P>::InitMemory() {
auto var = program_.scope->Var(var_desc->Name());
if (var_desc->Persistable()) {
auto tensor = var->template GetMutable<framework::LoDTensor>();
if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") {
continue;
}
LoadMemory(*var_desc, tensor,
program_.model_path + "/" + var_desc->Name());
} else {
......@@ -381,9 +415,11 @@ std::shared_ptr<framework::Tensor> Executor<Dtype, P>::predict(
template <typename Dtype, Precision P>
void Executor<Dtype, P>::predict(const framework::Tensor &t, int block_id) {
// framework::Variable *g_feed_value = program_.scope->Var("feed");
// auto feed_tensor = g_feed_value->GetMutable<framework::Tensor>();
// feed_tensor->ShareDataWith(t);
framework::Variable *g_feed_value = program_.scope->Var("feed");
auto feed_tensor = g_feed_value->GetMutable<framework::LoDTensor>();
feed_tensor->Resize(t.dims());
feed_tensor->ShareDataWith(t);
std::shared_ptr<framework::BlockDesc> to_predict_block =
to_predict_program_->Block(block_id);
......
......@@ -47,6 +47,8 @@ class Executor {
Executor(const framework::Program<Dtype> p);
Executor(const framework::Program<Dtype> p, int batch_size);
std::shared_ptr<framework::Tensor> predict(framework::Tensor &t);
std::vector<Ptype> predict(const std::vector<Ptype> &input,
......@@ -57,6 +59,7 @@ class Executor {
void LoadMemory(const framework::VarDesc var_desc,
framework::LoDTensor *tensor, const std::string &file_path);
framework::Program<Dtype> program_;
int batch_size_ = 1;
std::shared_ptr<framework::ProgramDesc> to_predict_program_;
void predict(const framework::Tensor &t, int block_id);
std::map<framework::BlockDesc,
......
......@@ -56,6 +56,11 @@ void ConvOp<Dtype, T>::InferShape() const {
std::vector<int> dilations = param_.Dilations();
PADDLE_MOBILE_ENFORCE((in_dims.size() == filter_dims.size() &&
dilations.size() == paddings.size() &&
paddings.size() == strides.size()),
"ConvParam is not suitable");
std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
for (size_t i = 0; i < strides.size(); ++i) {
output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2],
......
......@@ -13,3 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "feed_op.h"
namespace paddle_mobile {
namespace operators {
template class FeedOp<CPU, float>;
}
} // namespace paddle_mobile
......@@ -21,7 +21,7 @@ namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class FeedOp : framework::OperatorBase<DeviceType> {
class FeedOp : public framework::OperatorBase<DeviceType> {
public:
FeedOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const framework::AttributeMap attrs,
......@@ -32,8 +32,9 @@ class FeedOp : framework::OperatorBase<DeviceType> {
void Run() const { param_.Out()->ShareDataWith(*param_.InputX()); }
void InferShape() const {
auto x_dims = param_.InputX()->dims();
param_.Out()->Resize(x_dims);
auto out_dims = param_.Out()->dims();
out_dims[0] = param_.BatchSize();
param_.Out()->Resize(out_dims);
}
protected:
......@@ -41,8 +42,8 @@ class FeedOp : framework::OperatorBase<DeviceType> {
};
namespace ops = paddle_mobile::operators;
// USE_OP(Feed);
// REGISTER_OPERATOR(Feed, ops::FeedOp);
USE_OP(feed);
REGISTER_OPERATOR(feed, ops::FeedOp);
} // namespace operators
} // namespace paddle_mobile
......@@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
//
// Created by liuRuiLong on 2018/5/25.
//
#include "fetch_op.h"
namespace paddle_mobile {
namespace operators {
template class FetchOp<CPU, float>;
}
} // namespace paddle_mobile
......@@ -21,7 +21,7 @@ namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class FetchOp : framework::OperatorBase<DeviceType> {
class FetchOp : public framework::OperatorBase<DeviceType> {
public:
FetchOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const framework::AttributeMap attrs,
......@@ -29,7 +29,12 @@ class FetchOp : framework::OperatorBase<DeviceType> {
: framework::OperatorBase<DeviceType>(type, inputs, outputs, attrs,
scope),
param_(inputs, outputs, attrs, *scope) {}
void Run() const { param_.Out()->ShareDataWith(*param_.InputX()); }
void Run() const {
param_.Out()->ShareDataWith(*param_.InputX());
for (int i = 0; i < param_.Out()->numel(); ++i) {
DLOG << param_.Out()->template data<float>()[i];
}
}
void InferShape() const {
auto x_dims = param_.InputX()->dims();
......@@ -41,8 +46,8 @@ class FetchOp : framework::OperatorBase<DeviceType> {
};
namespace ops = paddle_mobile::operators;
// USE_OP(Fetch);
// REGISTER_OPERATOR(Fetch, ops::FetchOp);
USE_OP(fetch);
REGISTER_OPERATOR(fetch, ops::FetchOp);
} // namespace operators
} // namespace paddle_mobile
......@@ -197,8 +197,8 @@ class ConvParam : OpParam {
const framework::AttributeMap &attrs,
const framework::Scope &scope) {
filter_ = FilterFrom<LoDTensor>(inputs, scope);
input_ = InputFrom<Tensor>(inputs, scope);
output_ = OutputFrom<Tensor>(outputs, scope);
input_ = InputFrom<LoDTensor>(inputs, scope);
output_ = OutputFrom<LoDTensor>(outputs, scope);
strides_ = GetAttr<vector<int>>("strides", attrs);
paddings_ = GetAttr<vector<int>>("paddings", attrs);
dilations_ = GetAttr<vector<int>>("dilations", attrs);
......@@ -237,9 +237,9 @@ class ElementwiseAddParam : OpParam {
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
const framework::Scope &scope) {
input_x_ = InputXFrom<framework::Tensor>(inputs, scope);
input_y_ = InputYFrom<framework::Tensor>(inputs, scope);
out_ = OutFrom<framework::Tensor>(outputs, scope);
input_x_ = InputXFrom<framework::LoDTensor>(inputs, scope);
input_y_ = InputYFrom<framework::LoDTensor>(inputs, scope);
out_ = OutFrom<framework::LoDTensor>(outputs, scope);
axis_ = GetAttr<int>("axis", attrs);
}
......@@ -263,9 +263,9 @@ class MulParam : OpParam {
MulParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
const framework::Scope &scope) {
input_x_ = InputXFrom<framework::Tensor>(inputs, scope);
input_y_ = InputYFrom<framework::Tensor>(inputs, scope);
out_ = OutFrom<framework::Tensor>(outputs, scope);
input_x_ = InputXFrom<framework::LoDTensor>(inputs, scope);
input_y_ = InputYFrom<framework::LoDTensor>(inputs, scope);
out_ = OutFrom<framework::LoDTensor>(outputs, scope);
x_num_col_dims_ = GetAttr<int>("x_num_col_dims", attrs);
y_num_col_dims_ = GetAttr<int>("y_num_col_dims", attrs);
}
......@@ -293,19 +293,19 @@ class ConcatParam : public OpParam {
ConcatParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
const framework::Scope &scope) {
inputs_ = InputMultiFrom<framework::Tensor>(inputs, scope);
out_ = OutFrom<framework::Tensor>(outputs, scope);
inputs_ = InputMultiFrom<LoDTensor>(inputs, scope);
out_ = OutFrom<framework::LoDTensor>(outputs, scope);
axis_ = GetAttr<int>("axis", attrs);
}
vector<Tensor *> Inputs() const { return inputs_; }
vector<LoDTensor *> Inputs() const { return inputs_; }
Tensor *Out() const { return out_; }
const int &Axis() const { return axis_; }
private:
vector<Tensor *> inputs_;
vector<LoDTensor *> inputs_;
Tensor *out_;
int axis_;
};
......@@ -315,9 +315,9 @@ class LrnParam : public OpParam {
LrnParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
const framework::Scope &scope) {
input_x_ = InputXFrom<framework::Tensor>(inputs, scope);
out_ = OutFrom<framework::Tensor>(outputs, scope);
mid_out_ = MidOutFrom<framework::Tensor>(outputs, scope);
input_x_ = InputXFrom<framework::LoDTensor>(inputs, scope);
out_ = OutFrom<framework::LoDTensor>(outputs, scope);
mid_out_ = MidOutFrom<framework::LoDTensor>(outputs, scope);
n_ = GetAttr<int>("n", attrs);
alpha_ = GetAttr<float>("alpha", attrs);
beta_ = GetAttr<float>("beta", attrs);
......@@ -356,12 +356,12 @@ class BatchNormParam : OpParam {
BatchNormParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
const framework::Scope &scope) {
input_x_ = InputXFrom<framework::Tensor>(inputs, scope);
output_y_ = OutputYFrom<framework::Tensor>(outputs, scope);
input_bias_ = InputBiasFrom<framework::Tensor>(inputs, scope);
input_mean_ = InputMeanFrom<framework::Tensor>(inputs, scope);
input_scale_ = InputScaleFrom<framework::Tensor>(inputs, scope);
input_variance_ = InputVarianceFrom<framework::Tensor>(inputs, scope);
input_x_ = InputXFrom<framework::LoDTensor>(inputs, scope);
output_y_ = OutputYFrom<framework::LoDTensor>(outputs, scope);
input_bias_ = InputBiasFrom<framework::LoDTensor>(inputs, scope);
input_mean_ = InputMeanFrom<framework::LoDTensor>(inputs, scope);
input_scale_ = InputScaleFrom<framework::LoDTensor>(inputs, scope);
input_variance_ = InputVarianceFrom<framework::LoDTensor>(inputs, scope);
epsilon_ = GetAttr<float>("epsilon", attrs);
momentum_ = GetAttr<float>("momentum", attrs);
is_test_ = GetAttr<bool>("is_test", attrs);
......@@ -404,9 +404,9 @@ class PoolParam : public OpParam {
PoolParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
const framework::Scope &scope) {
input_ = InputXFrom<framework::Tensor>(inputs, scope);
input_ = InputXFrom<framework::LoDTensor>(inputs, scope);
output_ = OutFrom<framework::Tensor>(outputs, scope);
output_ = OutFrom<framework::LoDTensor>(outputs, scope);
pooling_type_ = GetAttr<string>("pooling_type", attrs);
ksize_ = GetAttr<vector<int>>("ksize", attrs);
strides_ = GetAttr<vector<int>>("strides", attrs);
......@@ -447,10 +447,11 @@ class PriorBoxParam : public OpParam {
PriorBoxParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
const framework::Scope &scope) {
input_ = InputFrom<framework::Tensor>(inputs, scope);
input_image_ = InputImageFrom<framework::Tensor>(inputs, scope);
output_boxes_ = OutputBoxesFrom<framework::Tensor>(outputs, scope);
output_variances_ = OutputVariancesFrom<framework::Tensor>(outputs, scope);
input_ = InputFrom<framework::LoDTensor>(inputs, scope);
input_image_ = InputImageFrom<framework::LoDTensor>(inputs, scope);
output_boxes_ = OutputBoxesFrom<framework::LoDTensor>(outputs, scope);
output_variances_ =
OutputVariancesFrom<framework::LoDTensor>(outputs, scope);
min_sizes_ = GetAttr<vector<float>>("min_sizes", attrs);
max_sizes_ = GetAttr<vector<float>>("max_sizes", attrs);
aspect_ratios_ = GetAttr<vector<float>>("aspect_ratios", attrs);
......@@ -508,10 +509,11 @@ class BoxCoderParam : public OpParam {
BoxCoderParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
const framework::Scope &scope) {
input_priorbox_ = InputPriorBoxFrom<framework::Tensor>(inputs, scope);
input_priorboxvar_ = InputPriorBoxVarFrom<framework::Tensor>(inputs, scope);
input_targetbox_ = InputTargetBoxFrom<framework::Tensor>(inputs, scope);
output_box_ = OutputBoxFrom<framework::Tensor>(outputs, scope);
input_priorbox_ = InputPriorBoxFrom<framework::LoDTensor>(inputs, scope);
input_priorboxvar_ =
InputPriorBoxVarFrom<framework::LoDTensor>(inputs, scope);
input_targetbox_ = InputTargetBoxFrom<framework::LoDTensor>(inputs, scope);
output_box_ = OutputBoxFrom<framework::LoDTensor>(outputs, scope);
code_type_ = GetAttr<std::string>("code_type", attrs);
}
const Tensor *InputPriorBox() const { return input_priorbox_; }
......@@ -537,8 +539,8 @@ class SoftmaxParam : public OpParam {
SoftmaxParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
const framework::Scope &scope) {
input_x_ = InputXFrom<framework::Tensor>(inputs, scope);
out_ = OutFrom<framework::Tensor>(outputs, scope);
input_x_ = InputXFrom<framework::LoDTensor>(inputs, scope);
out_ = OutFrom<framework::LoDTensor>(outputs, scope);
}
const Tensor *InputX() const { return input_x_; }
Tensor *Out() const { return out_; }
......@@ -553,8 +555,8 @@ class SigmoidParam : public OpParam {
SigmoidParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
const framework::Scope &scope) {
input_x_ = InputXFrom<framework::Tensor>(inputs, scope);
out_ = OutFrom<framework::Tensor>(outputs, scope);
input_x_ = InputXFrom<framework::LoDTensor>(inputs, scope);
out_ = OutFrom<framework::LoDTensor>(outputs, scope);
}
const Tensor *InputX() const { return input_x_; }
Tensor *Out() const { return out_; }
......@@ -568,9 +570,9 @@ class MultiClassNMSParam : public OpParam {
MultiClassNMSParam(const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs,
const Scope &scope) {
input_bboxes_ = InputBBoxesFrom<Tensor>(inputs, scope);
input_scores_ = InputScoresFrom<Tensor>(inputs, scope);
out_ = OutFrom<Tensor>(outputs, scope);
input_bboxes_ = InputBBoxesFrom<LoDTensor>(inputs, scope);
input_scores_ = InputScoresFrom<LoDTensor>(inputs, scope);
out_ = OutFrom<LoDTensor>(outputs, scope);
background_label_ = GetAttr<int>("background_label", attrs);
nms_top_k_ = GetAttr<int>("nms_top_k", attrs);
keep_top_k_ = GetAttr<int>("keep_top_k", attrs);
......@@ -612,17 +614,20 @@ class MultiClassNMSParam : public OpParam {
class FeedParam : public OpParam {
public:
FeedParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
const framework::Scope &scope) {
input_x_ = InputXFrom<framework::Tensor>(inputs, scope);
out_ = OutFrom<framework::Tensor>(outputs, scope);
const framework::AttributeMap &attrs, framework::Scope &scope) {
input_x_ = InputXFrom<framework::LoDTensor>(inputs, scope);
out_ = OutFrom<framework::LoDTensor>(outputs, scope);
auto var = scope.Var("batch_size");
batch_size = var->GetValue<int>();
}
const Tensor *InputX() const { return input_x_; }
Tensor *Out() const { return out_; }
const int BatchSize() const { return batch_size; }
private:
Tensor *input_x_;
Tensor *out_;
int batch_size;
};
class FetchParam : public OpParam {
......@@ -630,8 +635,8 @@ class FetchParam : public OpParam {
FetchParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
const framework::Scope &scope) {
input_x_ = InputXFrom<framework::Tensor>(inputs, scope);
out_ = OutFrom<framework::Tensor>(outputs, scope);
input_x_ = InputXFrom<framework::LoDTensor>(inputs, scope);
out_ = OutFrom<framework::LoDTensor>(outputs, scope);
}
const Tensor *InputX() const { return input_x_; }
Tensor *Out() const { return out_; }
......@@ -645,8 +650,8 @@ class TransposeParam : public OpParam {
public:
TransposeParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
input_x_ = InputXFrom<Tensor>(inputs, scope);
out_ = OutFrom<Tensor>(outputs, scope);
input_x_ = InputXFrom<LoDTensor>(inputs, scope);
out_ = OutFrom<LoDTensor>(outputs, scope);
axis_ = GetAttr<vector<int>>("axis", attrs);
}
......@@ -666,9 +671,9 @@ class ReshapeParam : public OpParam {
public:
ReshapeParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
input_x_ = InputXFrom<Tensor>(inputs, scope);
input_shape_ = InputShapeFrom<Tensor>(inputs, scope);
out_ = OutFrom<Tensor>(outputs, scope);
input_x_ = InputXFrom<LoDTensor>(inputs, scope);
input_shape_ = InputShapeFrom<LoDTensor>(inputs, scope);
out_ = OutFrom<LoDTensor>(outputs, scope);
shape_ = GetAttr<vector<int>>("shape", attrs);
inplace_ = GetAttr<bool>("inplace", attrs);
}
......@@ -695,8 +700,8 @@ class ReluParam : public OpParam {
public:
ReluParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
input_x_ = InputXFrom<Tensor>(inputs, scope);
out_ = OutFrom<Tensor>(outputs, scope);
input_x_ = InputXFrom<LoDTensor>(inputs, scope);
out_ = OutFrom<LoDTensor>(outputs, scope);
}
const Tensor *InputX() const { return input_x_; }
......
......@@ -49,7 +49,6 @@ void PoolOp<DeviceType, T>::InferShape() const {
paddings[i], strides[i], ceil_mode));
}
param_.Output()->Resize(framework::make_ddim(output_shape));
DLOG << "infer shape out size =" << param_.Output()->numel();
}
template class PoolOp<CPU, float>;
} // namespace operators
......
......@@ -24,7 +24,7 @@ int main() {
// ../../../test/models/mobilenet
auto program = loader.Load(std::string("../models/googlenet"));
paddle_mobile::Executor<paddle_mobile::CPU> executor(program);
paddle_mobile::Executor<paddle_mobile::CPU> executor(program, 1);
std::vector<float> input;
std::vector<int64_t> dims{1, 3, 224, 224};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册