未验证 提交 362275ed 编写于 作者: 石晓伟 提交者: GitHub

Add kernel version table and update framework.proto, test=develop (#2243)

* update framework.proto

* add compatibility check, test=develop

* remove head files, test=develop
上级 06d7a8f5
......@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
syntax = "proto2";
// option optimize_for = LITE_RUNTIME;
option optimize_for = LITE_RUNTIME;
package paddle.framework.proto;
// Any incompatible changes to ProgramDesc and its dependencies should
......@@ -166,6 +166,9 @@ message VarDesc {
required string name = 1;
required VarType type = 2;
optional bool persistable = 3 [ default = false ];
// True if the variable is an input data and
// have to check the feed data shape and dtype
optional bool need_check_feed = 4 [ default = false ];
}
message BlockDesc {
......@@ -176,13 +179,39 @@ message BlockDesc {
optional int32 forward_block_idx = 5 [ default = -1 ];
}
// CompatibleInfo is used to determine if a feature is compatible and
// provides the information.
message CompatibleInfo {
enum Type {
COMPATIBLE = 0;
DEFINITELY_NOT = 1;
POSSIBLE = 2;
BUG_FIX = 3;
PRECISION_CHANGE = 4;
}
required string version = 1;
required Type type = 2;
}
// In some cases, Paddle Fluid may perform operator definition iterations,
// and the operator uses OpCompatibleMap for compatibility testing.
message OpCompatibleMap {
message OpCompatiblePair {
required string op_name = 1;
required CompatibleInfo compatible_info = 2;
}
repeated OpCompatiblePair pair = 1;
optional string default_required_version = 2;
}
// Please refer to
// https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/program.md
// for more details.
// TODO(panyx0718): A model can have multiple programs. Need a
// way to distinguish them. Maybe ID or name?
message ProgramDesc {
reserved 2; // For backward compatibility.
repeated BlockDesc blocks = 1;
optional Version version = 2;
optional Version version = 4;
optional OpCompatibleMap op_compatible_map = 3;
}
......@@ -27,6 +27,7 @@
#include <unordered_set>
#include <vector>
#include "lite/core/tensor.h"
#include "lite/core/version.h"
#include "lite/utils/all.h"
namespace paddle {
......@@ -280,7 +281,7 @@ struct ParamTypeRecorder {
*/
class ParamTypeRegistry {
public:
enum class IO : int { kInput = 0, kOutput };
enum class IO : int { kInvalid = 0, kInput, kOutput };
template <TargetType target,
PrecisionType precision,
......@@ -310,6 +311,12 @@ class ParamTypeRegistry {
kernel_type_, Place{target, precision, layout}, arg_name, ptype);
return *this;
}
NewInstance& SetVersion(const std::string& version) {
ParamTypeRegistry::Global().SetVersion(int_version(version),
Split(kernel_type_, "/").front(),
Place{target, precision, layout});
return *this;
}
bool Finalize() { return true; }
......@@ -327,6 +334,22 @@ class ParamTypeRegistry {
CHECK(types_.count(key));
}
void SetVersion(const int64_t version,
const std::string& kernel_type,
const Place& place) {
KernelIdTy key{kernel_type, place, IO(), std::string()};
versions_[key] = version;
CHECK(versions_.count(key));
}
int64_t GetVersion(const std::string& kernel_type, const Place& place) {
KernelIdTy key{kernel_type, place, IO(), std::string()};
if (versions_.count(key)) {
return versions_[key];
}
return -1;
}
const ParamType* RetrieveInArgument(const Place& place,
const std::string& op_type,
const std::string& arg_name) {
......@@ -384,6 +407,7 @@ class ParamTypeRegistry {
private:
std::map<key_t, ParamType, ParamTypeRegistry::KeyCmp> types_;
std::map<key_t, int64_t, ParamTypeRegistry::KeyCmp> versions_;
};
} // namespace lite
......
......@@ -16,10 +16,15 @@
#include <string>
#include "lite/utils/replace_stl/stream.h"
#include "lite/utils/string.h"
namespace paddle {
namespace lite {
static constexpr int MAJOR_COEFF = 1000000;
static constexpr int MINOR_COEFF = 1000;
static constexpr int PATCH_COEFF = 1;
static std::string paddlelite_commit() {
return "@PADDLE_LITE_COMMIT@";
}
......@@ -45,5 +50,15 @@ static std::string version() {
return ss.str();
}
static int64_t int_version(const std::string& version) {
const std::vector<std::string> vec = Split(version, ".");
if (vec.size() == 3) {
return std::stoi(vec[0]) * MAJOR_COEFF +
std::stoi(vec[1]) * MINOR_COEFF +
std::stoi(vec[2]) * PATCH_COEFF;
}
return -1;
}
} // namespace lite
} // namespace paddle
......@@ -33,7 +33,6 @@ nv_test(transpose_compute_cuda_test SRCS transpose_compute_test.cc DEPS transpos
nv_test(concat_compute_cuda_test SRCS concat_compute_test.cc DEPS concat_compute_cuda)
nv_test(elementwise_add_compute_cuda_test SRCS elementwise_add_compute_test.cc DEPS elementwise_add_compute_cuda)
nv_test(softmax_compute_cuda_test SRCS softmax_compute_test.cc DEPS softmax_compute_cuda)
nv_test(pool_compute_cuda_test SRCS pool_compute_test.cc DEPS pool_compute_cuda)
#nv_test(layout_cuda_test SRCS layout_compute_test.cc DEPS layout_compute_cuda)
nv_test(mul_compute_cuda_test SRCS mul_compute_test.cc DEPS mul_compute_cuda)
nv_test(dropout_compute_cuda_test SRCS dropout_compute_test.cc DEPS dropout_compute_cuda )
......
......@@ -66,4 +66,5 @@ REGISTER_LITE_KERNEL(leaky_relu,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.SetVersion("1.5.0")
.Finalize();
......@@ -29,5 +29,15 @@ lite_cc_library(model_parser SRCS model_parser.cc DEPS
compatible_pb
memory
CUDA_DEPS target_wrapper_cuda)
lite_cc_test(test_compatible_pb SRCS compatible_pb_test.cc DEPS compatible_pb)
if (LITE_WITH_CUDA AND NOT LITE_ON_TINY_PUBLISH)
lite_cc_library(compatibility SRCS compatibility.cc DEPS
kernel
variable
compatible_pb
type_system
${cpp_wrapper}
${naive_wrapper})
lite_cc_test(test_compatibility SRCS compatibility_test.cc DEPS compatibility leaky_relu_compute_cuda)
endif()
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/model_parser/compatibility.h"
#include "lite/core/type_system.h"
#include "lite/model_parser/naive_buffer/block_desc.h"
#include "lite/model_parser/naive_buffer/op_desc.h"
#include "lite/model_parser/naive_buffer/program_desc.h"
#include "lite/model_parser/naive_buffer/var_desc.h"
#ifndef LITE_ON_TINY_PUBLISH
#include "lite/model_parser/cpp/block_desc.h"
#include "lite/model_parser/cpp/op_desc.h"
#include "lite/model_parser/cpp/program_desc.h"
#include "lite/model_parser/cpp/var_desc.h"
#endif
namespace paddle {
namespace lite {
template <typename T>
bool CompatibleChecker<T>::CheckKernelVersion(const std::string& type,
const lite_api::Place& place) {
int64_t impl_version = ParamTypeRegistry::Global().GetVersion(type, place);
const int64_t prog_version = program_.Version();
VLOG(3) << "Kernel implement version: " << type << ", " << impl_version;
VLOG(3) << "Kernel program version: " << type << ", " << prog_version;
if (impl_version == -1) {
impl_version = mini_version_;
}
return prog_version <= impl_version;
}
template <typename T>
std::unordered_set<std::string> CompatibleChecker<T>::OpsType(T* program) {
LOG(WARNING) << "OpsType() is not yet implemented.";
return std::unordered_set<std::string>();
}
#ifndef LITE_ON_TINY_PUBLISH
template <>
std::unordered_set<std::string> CompatibleChecker<cpp::ProgramDesc>::OpsType(
cpp::ProgramDesc* program) {
std::unordered_set<std::string> ops_type;
for (size_t i = 0; i < program->BlocksSize(); ++i) {
auto* block = program->GetBlock<cpp::BlockDesc>(i);
for (size_t j = 0; j < block->OpsSize(); ++j) {
auto* op = block->GetOp<cpp::OpDesc>(j);
ops_type.insert(op->Type());
}
}
return ops_type;
}
template class CompatibleChecker<cpp::ProgramDesc>;
#endif // LITE_ON_TINY_PUBLISH
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_set>
#include "lite/api/paddle_place.h"
#include "lite/model_parser/desc_apis.h"
namespace paddle {
namespace lite {
template <typename T>
class CompatibleChecker {
public:
explicit CompatibleChecker(const T& program,
const int64_t mini_version = 1005000)
: program_(program), mini_version_(mini_version) {}
bool operator()(const lite_api::Place& place) {
bool status = true;
const std::unordered_set<std::string>& ops_type = OpsType(&program_);
if (ops_type.empty()) {
VLOG(3) << "You are checking the compatibility of an empty program.";
}
for (const auto& type : ops_type) {
bool ret = CheckKernelVersion(type, place);
VLOG(3) << "Kernel version is supported: " << type << ", " << ret;
status = status && ret;
}
return status;
}
private:
std::unordered_set<std::string> OpsType(T* program);
bool CheckKernelVersion(const std::string& type,
const lite_api::Place& place);
T program_;
int64_t mini_version_;
};
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/model_parser/compatibility.h"
#include <gtest/gtest.h>
#include "lite/api/paddle_lite_factory_helper.h"
#include "lite/model_parser/compatible_pb.h"
#include "lite/model_parser/cpp/block_desc.h"
#include "lite/model_parser/cpp/op_desc.h"
#include "lite/model_parser/cpp/program_desc.h"
#include "lite/model_parser/cpp/var_desc.h"
USE_LITE_KERNEL(leaky_relu, kCUDA, kFloat, kNCHW, def);
namespace paddle {
namespace lite {
static constexpr int64_t version = 1005000;
TEST(CompatibleChecker, CppProgramDesc) {
cpp::ProgramDesc program;
program.SetVersion(version);
auto* block = program.AddBlock<cpp::BlockDesc>();
auto* op = block->AddOp<cpp::OpDesc>();
op->SetType("leaky_relu");
CompatibleChecker<cpp::ProgramDesc> checker(program);
lite_api::Place place{TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW)};
CHECK(checker(place));
}
} // namespace lite
} // namespace paddle
......@@ -143,8 +143,6 @@ class OpDesc : public OpDescAPI {
template <typename T>
T GetAttr(const std::string &name) const;
std::string DebugString() const { return desc_->DebugString(); }
private:
std::vector<std::string> GetArguments(
const google::protobuf::RepeatedPtrField<framework::proto::OpDesc_Var>
......
......@@ -292,7 +292,7 @@ class RegisterLiteKernelParser(SyntaxParser):
self.eat_point()
self.eat_spaces()
self.eat_word()
assert self.token in ('BindInput', 'BindOutput', 'Finalize')
assert self.token in ('BindInput', 'BindOutput', 'SetVersion', 'Finalize')
io = IO()
if self.token == 'BindInput':
......@@ -301,6 +301,12 @@ class RegisterLiteKernelParser(SyntaxParser):
elif self.token == 'BindOutput':
eat_io(False, io)
k.outputs.append(io)
elif self.token == 'SetVersion':
self.eat_left_parentheses()
self.eat_str()
self.version = self.token
self.eat_right_parentheses()
self.eat_spaces()
else:
self.eat_left_parentheses()
self.eat_right_parentheses()
......
......@@ -21,6 +21,7 @@
#include "lite/utils/hash.h"
#include "lite/utils/io.h"
#include "lite/utils/macros.h"
#include "lite/utils/string.h"
#include "lite/utils/varient.h"
#ifdef LITE_ON_TINY_PUBLISH
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册