提交 e579eb00 编写于 作者: S superjomn

update

上级 380de6da
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/fc_op.h"
#include <gtest/gtest.h>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
TEST(fc_op_lite, test) {
LOG(INFO) << "\n" << KernelRegistry::Global().DebugString();
// prepare variables
Scope scope;
auto* x = scope.Var("x")->GetMutable<Tensor>();
auto* w = scope.Var("w")->GetMutable<Tensor>();
auto* bias = scope.Var("bias")->GetMutable<Tensor>();
auto* output = scope.Var("output")->GetMutable<Tensor>();
x->Resize({1, 10, 20});
w->Resize({20, 20});
bias->Resize({1, 10});
output->Resize({10, 20});
// set data
for (int i = 0; i < 10 * 20; i++) {
x->mutable_data<float>()[i] = i;
}
for (int i = 0; i < 20 * 20; i++) {
w->mutable_data<float>()[i] = i;
}
for (int i = 0; i < 1 * 10; i++) {
bias->mutable_data<float>()[i] = i;
}
for (int i = 0; i < 10 * 20; i++) {
output->mutable_data<float>()[i] = 0.;
}
// prepare op desc
framework::OpDesc desc;
desc.SetType("fc");
desc.SetInput("Input", {"x"});
desc.SetInput("W", {"w"});
desc.SetInput("Bias", {"bias"});
desc.SetOutput("Out", {"output"});
desc.SetAttr("in_num_col_dims", static_cast<int>(1));
FcOpLite fc("fc");
fc.SetValidPlaces({OpLite::Place{TARGET(kHost), PRECISION(kFloat)}});
fc.PickKernel({OpLite::Place{TARGET(kHost), PRECISION(kFloat)}});
fc.Attach(desc, &scope);
fc.Run();
for (int i = 0; i < 10 * 20; i++) {
LOG(INFO) << output->data<float>()[i];
}
}
} // namespace operators
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(fc, kHost, kFloat);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/op_params.h"
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/lite/core/tensor.h"
#include "paddle/fluid/lite/utils/all.h"
/*
* This file contains all the argument parameter data structure for operators.
*/
namespace paddle {
namespace lite {
namespace operators {
struct FcParam {
Tensor* input{nullptr};
Tensor* w{};
Tensor* bias{};
Tensor* output{};
DDim in_mat_dims;
int in_num_col_dims{0};
};
using param_t = variant<FcParam>;
} // namespace operators
} // namespace lite
} // namespace paddle
cc_test(test_varient SRCS varient_test.cc)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/utils/varient.h"
#include <gtest/gtest.h>
#include <set>
#include <string>
#include "paddle/fluid/lite/utils/all.h"
namespace paddle {
namespace lite {
namespace utils {
TEST(varient, test) {
variant<int, float> a;
a.set<int>(1);
ASSERT_EQ(a.get<int>(), 1);
a.set<int>(20);
ASSERT_EQ(a.get<int>(), 20);
}
TEST(varient, reference) {
variant<int, float, std::string> a;
a.set<std::string>("hello world");
auto& b = a.get<std::string>();
ASSERT_EQ(b, "hello world");
}
TEST(varient, get_wrong_type) {
variant<int, float> a;
a.set<int>(100);
bool exception = false;
try {
float b = a.get<float>();
LOG(INFO) << b + 1;
} catch (...) {
exception = true;
}
ASSERT_TRUE(exception);
}
} // namespace utils
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册