未验证 提交 95e0497f 编写于 作者: S suiyang 提交者: GitHub

Merge pull request #1214 from Eclipsess/develop

fix #1213 add test-eng
...@@ -29,10 +29,9 @@ void FusionFcCompute(const FusionFcParam<CPU> &param) { ...@@ -29,10 +29,9 @@ void FusionFcCompute(const FusionFcParam<CPU> &param) {
auto *input_z_data = input_z->data<float>(); auto *input_z_data = input_z->data<float>();
int axis = param.Axis(); int axis = param.Axis();
Tensor *out = param.Out(); Tensor *out = param.Out();
auto *out_data = out->mutable_data<float>();
// int m = out->dims()[0]; // int m = out->dims()[0];
// int n = out->dims()[1]; // int n = out->dims()[1];
auto *out_data = out->mutable_data<float>();
const Tensor x_matrix = const Tensor x_matrix =
input_x->dims().size() > 2 input_x->dims().size() > 2
? framework::ReshapeToMatrix(*input_x, param.XNumColDims()) ? framework::ReshapeToMatrix(*input_x, param.XNumColDims())
......
...@@ -83,6 +83,7 @@ void PoolCompute(const PoolParam<CPU> &param) { ...@@ -83,6 +83,7 @@ void PoolCompute(const PoolParam<CPU> &param) {
#if __aarch64__ #if __aarch64__
PoolBasic(pooling_type, ksize, strides, paddings, in_x, out); PoolBasic(pooling_type, ksize, strides, paddings, in_x, out);
#else #else
/// todo: fix bug in Pool2x2
if (pooling_type == "max") { if (pooling_type == "max") {
math::Pool2x2Maxs2p0(strides, paddings, in_x, out); math::Pool2x2Maxs2p0(strides, paddings, in_x, out);
} else if (pooling_type == "avg") { } else if (pooling_type == "avg") {
......
...@@ -24,6 +24,7 @@ void SoftmaxCompute(const SoftmaxParam<CPU> &param) { ...@@ -24,6 +24,7 @@ void SoftmaxCompute(const SoftmaxParam<CPU> &param) {
Tensor *out = param.Out(); Tensor *out = param.Out();
auto x_dims = in_x->dims(); auto x_dims = in_x->dims();
out->Resize(x_dims); out->Resize(x_dims);
out->mutable_data<float>();
math::SoftmaxFuntor<CPU, float>()(in_x, out); math::SoftmaxFuntor<CPU, float>()(in_x, out);
} }
} // namespace operators } // namespace operators
......
...@@ -2147,9 +2147,9 @@ class Im2SequenceParam : public OpParam { ...@@ -2147,9 +2147,9 @@ class Im2SequenceParam : public OpParam {
paddings_ = GetAttr<vector<int>>("paddings", attrs); paddings_ = GetAttr<vector<int>>("paddings", attrs);
} }
const RType *Input() const { return input_x_; } const GType *Input() const { return input_x_; }
RType *Output() const { return out_; } GType *Output() const { return out_; }
const vector<int> &Kernels() const { return kernels_; } const vector<int> &Kernels() const { return kernels_; }
...@@ -2158,8 +2158,8 @@ class Im2SequenceParam : public OpParam { ...@@ -2158,8 +2158,8 @@ class Im2SequenceParam : public OpParam {
const vector<int> &Paddings() const { return paddings_; } const vector<int> &Paddings() const { return paddings_; }
private: private:
RType *input_x_; GType *input_x_;
RType *out_; GType *out_;
vector<int> kernels_; vector<int> kernels_;
vector<int> strides_; vector<int> strides_;
vector<int> paddings_; vector<int> paddings_;
......
...@@ -23,13 +23,13 @@ int main() { ...@@ -23,13 +23,13 @@ int main() {
// paddle_mobile.SetThreadNum(4); // paddle_mobile.SetThreadNum(4);
auto time1 = time(); auto time1 = time();
if (paddle_mobile.Load(std::string(g_eng) + "/model", if (paddle_mobile.Load(std::string(g_eng) + "/model",
std::string(g_eng) + "/params", false, false, 1, std::string(g_eng) + "/params", true, false, 1,
true)) { true)) {
auto time2 = time(); auto time2 = time();
std::cout << "load cost :" << time_diff(time1, time1) << "ms" << std::endl; std::cout << "load cost :" << time_diff(time1, time1) << "ms" << std::endl;
std::vector<int64_t> dims{1, 1, 48, 512}; std::vector<int64_t> dims{1, 1, 48, 400};
LoDTensor input_tensor; LoDTensor input_tensor;
SetupTensor<float>(&input_tensor, {1, 1, 48, 512}, static_cast<float>(0), SetupTensor<float>(&input_tensor, {1, 1, 48, 400}, static_cast<float>(0),
static_cast<float>(1)); static_cast<float>(1));
std::vector<float> input(input_tensor.data<float>(), std::vector<float> input(input_tensor.data<float>(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册