Commit 10db87ed authored by baolei.an

pass code style check

Parent 36004a9a
......@@ -6,5 +6,5 @@ endif()
lite_cc_library(arena_framework SRCS framework.cc DEPS program gtest)
if((NOT LITE_WITH_OPENCL) AND (LITE_WITH_X86 OR LITE_WITH_ARM))
lite_cc_test(test_arena_framework SRCS framework_test.cc DEPS arena_framework ${npu_kernels} ${xpu_kernels} ${x86_kernels} ${cuda_kernels} ${fpga_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_arena_framework SRCS framework_test.cc DEPS arena_framework ${bm_kernels} ${npu_kernels} ${xpu_kernels} ${x86_kernels} ${cuda_kernels} ${fpga_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
endif()
......@@ -12,54 +12,51 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <bmcompiler_if.h>
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/bm/bridges/graph.h"
#include "bmcompiler_if.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace bm {
int ActConverter(void* ctx, OpLite* op, KernelBase* kernel){
int ActConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto scope = op->scope();
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto x_var_name = op_info->Input("X").front();
auto x = scope->FindVar(x_var_name)->GetMutable<lite::Tensor>();
auto x_dims = x->dims();
auto output_var_name = op_info->Output("Out").front();
auto output = scope->FindVar(output_var_name)->GetMutable<lite::Tensor>();
auto output =
scope->FindVar(output_var_name)->GetMutable<lite::Tensor>();
auto output_dims = output->dims();
const long int* x_shape_data = const_cast<const long int*>(&x_dims.data()[0]);
const long int* output_shape_data = const_cast<const long int*>(&output_dims.data()[0]);
int i_x_shape_data[x_dims.size()];
int i_output_shape_data[output_dims.size()];
const int64_t* x_shape_data =
const_cast<const int64_t*>(&x_dims.data()[0]);
const int64_t* output_shape_data =
const_cast<const int64_t*>(&output_dims.data()[0]);
std::vector<int32_t> i_x_shape_data(x_dims.size());
std::vector<int32_t> i_output_shape_data(output_dims.size());
for (size_t i = 0; i < x_dims.size(); i++) {
i_x_shape_data[i] = static_cast<int>(x_shape_data[i]);
}
for (size_t i = 0; i < output_dims.size(); i++) {
i_output_shape_data[i] = static_cast<int>(output_shape_data[i]);
}
CHECK(op_type == "relu");
CHECK_EQ(op_type, "relu");
add_relu_layer(graph->GetCompilerHandle(),
const_cast<const int*>(i_x_shape_data),
x_dims.size(),
static_cast<const char*>(x_var_name.c_str()),
const_cast<const int*>(i_output_shape_data),
output_dims.size(),
static_cast<const char*>(output_var_name.c_str()),
0.f,
-1.f);
const_cast<const int*>(&i_x_shape_data[0]),
x_dims.size(),
static_cast<const char*>(x_var_name.c_str()),
const_cast<const int*>(&i_output_shape_data[0]),
output_dims.size(),
static_cast<const char*>(output_var_name.c_str()),
0.f,
-1.f);
graph->AddNode(output_var_name);
return SUCCESS;
}
......@@ -69,4 +66,5 @@ int ActConverter(void* ctx, OpLite* op, KernelBase* kernel){
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(BM, relu, paddle::lite::subgraph::bm::ActConverter);
REGISTER_SUBGRAPH_BRIDGE(relu, kBM,
paddle::lite::subgraph::bm::ActConverter);
......@@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <bmcompiler_if.h>
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/bm/bridges/graph.h"
#include "lite/kernels/bm/bridges/utility.h"
#include "bmcompiler_if.h"
namespace paddle {
namespace lite {
......@@ -30,49 +30,41 @@ int BatchNormConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto unique_op_name = lite::subgraph::bm::UniqueName(op_type);
// input
auto x_var_name = op_info->Input("X").front();
auto x = scope->FindVar(x_var_name)->GetMutable<lite::Tensor>();
auto x_dims = x->dims();
const long int* x_shape_data = const_cast<const long int*>(&x_dims.data()[0]);
int i_x_shape_data[x_dims.size()];
const int64_t* x_shape_data = const_cast<const int64_t*>(&x_dims.data()[0]);
std::vector<int32_t> i_x_shape_data(x_dims.size());
for (size_t i = 0; i < x_dims.size(); i++) {
i_x_shape_data[i] = static_cast<int>(x_shape_data[i]);
}
int channel_size = x_dims[1];
auto scale_var_name = op_info->Input("Scale").front();
auto scale = scope->FindVar(scale_var_name)->GetMutable<lite::Tensor>();
auto bias_var_name = op_info->Input("Bias").front();
auto bias = scope->FindVar(bias_var_name)->GetMutable<lite::Tensor>();
auto mean_var_name = op_info->Input("Mean").front();
auto mean = scope->FindVar(mean_var_name)->GetMutable<lite::Tensor>();
auto variance_var_name = op_info->Input("Variance").front();
auto variance = scope->FindVar(variance_var_name)->GetMutable<lite::Tensor>();
auto variance =
scope->FindVar(variance_var_name)->GetMutable<lite::Tensor>();
// output
auto output_var_name = op_info->Output("Y").front();
auto output = scope->FindVar(output_var_name)->GetMutable<lite::Tensor>();
auto output_dims = output->dims();
const long int* output_shape_data = const_cast<const long int*>(&output_dims.data()[0]);
int i_output_shape_data[output_dims.size()];
const int64_t* output_shape_data =
const_cast<const int64_t*>(&output_dims.data()[0]);
std::vector<int32_t> i_output_shape_data(output_dims.size());
for (size_t i = 0; i < output_dims.size(); i++) {
i_output_shape_data[i] = static_cast<int>(output_shape_data[i]);
}
auto epsilon = op_info->GetAttr<float>("epsilon");
auto unique_bn_out_name = lite::subgraph::bm::UniqueName("batch_norm_out");
auto* scale_data = scale->mutable_data<float>();
auto* bias_data = bias->mutable_data<float>();
auto* mean_data = mean->mutable_data<float>();
auto* variance_data = variance->mutable_data<float>();
for (int c = 0; c < channel_size; c++) {
float inv_scale = 1.f / (std::sqrt(variance_data[c] + epsilon));
bias_data[c] = bias_data[c] - inv_scale * scale_data[c] * mean_data[c];
......@@ -83,17 +75,15 @@ int BatchNormConverter(void* ctx, OpLite* op, KernelBase* kernel) {
int **shape = new int *[input_num];
int *dim = new int[input_num];
const char **name = new const char *[input_num];
name[0] = static_cast<const char*>(x_var_name.c_str());
dim[0] = x_dims.size();
shape[0] = i_x_shape_data;
shape[0] = &i_x_shape_data[0];
add_scale_layer(graph->GetCompilerHandle(),
input_num,
shape,
dim,
name,
const_cast<const int*>(i_output_shape_data),
const_cast<const int*>(&i_output_shape_data[0]),
output_dims.size(),
static_cast<const char*>(output_var_name.c_str()),
static_cast<const char*>(unique_op_name.c_str()),
......@@ -102,7 +92,6 @@ int BatchNormConverter(void* ctx, OpLite* op, KernelBase* kernel) {
1,
1,
1);
delete [] shape;
delete [] name;
delete [] dim;
......@@ -116,4 +105,5 @@ int BatchNormConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(BM, batch_norm, paddle::lite::subgraph::bm::BatchNormConverter);
REGISTER_SUBGRAPH_BRIDGE(batch_norm, kBM,
paddle::lite::subgraph::bm::BatchNormConverter);
......@@ -12,11 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <bmcompiler_if.h>
#include "lite/operators/conv_op.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/bm/bridges/graph.h"
#include "lite/kernels/bm/bridges/utility.h"
#include "bmcompiler_if.h"
namespace paddle {
namespace lite {
......@@ -26,13 +27,11 @@ namespace bm {
int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto scope = op->scope();
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto unique_op_name = lite::subgraph::bm::UniqueName(op_type);
auto input_var_name = op_info->Input("Input").front();
auto input = scope->FindVar(input_var_name)->GetMutable<lite::Tensor>();
auto input_dims = input->dims();
......@@ -42,11 +41,9 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto filter_var_name = op_info->Input("Filter").front();
auto filter = scope->FindVar(filter_var_name)->GetMutable<lite::Tensor>();
auto filter_dims = filter->dims();
CHECK(input_dims.size() == 4);
CHECK(output_dims.size() == 4);
CHECK(filter_dims.size() == 4);
CHECK_EQ(input_dims.size(), 4);
CHECK_EQ(output_dims.size(), 4);
CHECK_EQ(filter_dims.size(), 4);
bool has_bias = lite::subgraph::bm::HasInputArg(op_info, scope, "Bias");
float* bias_data = nullptr;
if (has_bias) {
......@@ -54,33 +51,31 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto* bias = scope->FindVar(bias_var_name)->GetMutable<lite::Tensor>();
bias_data = static_cast<float*>(bias->mutable_data<float>());
}
const long int* input_shape_data = const_cast<const long int*>(&input_dims.data()[0]);
const long int* output_shape_data = const_cast<const long int*>(&output_dims.data()[0]);
int i_input_shape_data[input_dims.size()];
int i_output_shape_data[output_dims.size()];
const int64_t* input_shape_data =
const_cast<const int64_t*>(&input_dims.data()[0]);
const int64_t* output_shape_data =
const_cast<const int64_t*>(&output_dims.data()[0]);
std::vector<int32_t> i_input_shape_data(input_dims.size());
std::vector<int32_t> i_output_shape_data(output_dims.size());
for (size_t i = 0; i < input_dims.size(); i++) {
i_input_shape_data[i] = static_cast<int>(input_shape_data[i]);
}
for (size_t i = 0; i < output_dims.size(); i++) {
i_output_shape_data[i] = static_cast<int>(output_shape_data[i]);
}
const float* filter_data = const_cast<const float*>(filter->mutable_data<float>());
const float* filter_data =
const_cast<const float*>(filter->mutable_data<float>());
auto groups = op_info->GetAttr<int>("groups");
auto paddings = op_info->GetAttr<std::vector<int>>("paddings");
auto strides = op_info->GetAttr<std::vector<int>>("strides");
auto dilations = op_info->GetAttr<std::vector<int>>("dilations");
add_conv_layer(graph->GetCompilerHandle(),
const_cast<const int*>(i_input_shape_data),
const_cast<const int*>(&i_input_shape_data[0]),
input_dims.size(),
static_cast<const char*>(input_var_name.c_str()),
const_cast<const int*>(i_output_shape_data),
const_cast<const int*>(&i_output_shape_data[0]),
output_dims.size(),
static_cast<const char*>(output_var_name.c_str()),
static_cast<const char*>(unique_op_name.c_str()),
......@@ -107,4 +102,5 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(BM, conv2d, paddle::lite::subgraph::bm::ConvConverter);
REGISTER_SUBGRAPH_BRIDGE(conv2d, kBM,
paddle::lite::subgraph::bm::ConvConverter);
......@@ -11,13 +11,12 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <bmcompiler_if.h>
#include <bmcompiler_if_lite.h>
#include <bmcompiler_defs.h>
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/bm/bridges/graph.h"
#include "lite/kernels/bm/bridges/utility.h"
#include "bmcompiler_if.h"
#include "bmcompiler_if_lite.h"
#include "bmcompiler_defs.h"
namespace paddle {
namespace lite {
......@@ -27,111 +26,106 @@ namespace bm {
int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto scope = op->scope();
auto op_info = op->op_info();
auto op_type = op_info->Type();
// input
const int input_num = 2;
int **shape = new int *[input_num];
int *dim = new int[input_num];
const char **name = new const char *[input_num];
auto x_var_name = op_info->Input("X").front();
auto x = scope->FindVar(x_var_name)->GetMutable<lite::Tensor>();
auto x_dims = x->dims();
name[0] = static_cast<const char*>(x_var_name.c_str());
dim[0] = x_dims.size();
const long int* x_shape_data = const_cast<const long int*>(&x_dims.data()[0]);
int i_x_shape_data[x_dims.size()];
const int64_t* x_shape_data =
const_cast<const int64_t*>(&x_dims.data()[0]);
std::vector<int32_t> i_x_shape_data(x_dims.size());
for (size_t i = 0; i < x_dims.size(); i++) {
i_x_shape_data[i] = static_cast<int>(x_shape_data[i]);
}
shape[0] = i_x_shape_data;
shape[0] = &i_x_shape_data[0];
auto y_var_name = op_info->Input("Y").front();
auto y = scope->FindVar(y_var_name)->GetMutable<lite::Tensor>();
auto y_dims = y->dims();
name[1] = static_cast<const char*>(y_var_name.c_str());
dim[1] = y_dims.size();
const long int* y_shape_data = const_cast<const long int*>(&y_dims.data()[0]);
int i_y_shape_data[y_dims.size()];
const int64_t* y_shape_data =
const_cast<const int64_t*>(&y_dims.data()[0]);
std::vector<int32_t> i_y_shape_data(y_dims.size());
for (size_t i = 0; i < y_dims.size(); i++) {
i_y_shape_data[i] = static_cast<int>(y_shape_data[i]);
}
shape[1] = i_y_shape_data;
shape[1] = &i_y_shape_data[0];
bool y_is_const = !graph->HasNode(y_var_name);
// output
auto output_var_name = op_info->Output("Out").front();
auto output = scope->FindVar(output_var_name)->GetMutable<lite::Tensor>();
auto output =
scope->FindVar(output_var_name)->GetMutable<lite::Tensor>();
auto output_dims = output->dims();
const long int* output_shape_data = const_cast<const long int*>(&output_dims.data()[0]);
int i_output_shape_data[output_dims.size()];
const int64_t* output_shape_data =
const_cast<const int64_t*>(&output_dims.data()[0]);
std::vector<int32_t> i_output_shape_data(output_dims.size());
for (size_t i = 0; i < output_dims.size(); i++) {
i_output_shape_data[i] = static_cast<int>(output_shape_data[i]);
}
if (y_is_const) {
CHECK(op_type == "elementwise_add");
CHECK_EQ(op_type, "elementwise_add");
}
int op_code{-1};
float coeff[2] = {1.f, 1.f};
if (op_type == "elementwise_mul") {
op_code = 0;
} else if (op_type == "elementwise_add") {
op_code = 1;
} else if(op_type == "elementwise_sub") {
} else if (op_type == "elementwise_sub") {
op_code = 1;
coeff[1] = -1.f;
} else {
LOG(FATAL) << "UNSUPPORTED ELTWISE OPERATION: " << op_type;
}
if (!y_is_const) {
add_eltwise_layer(graph->GetCompilerHandle(),
input_num,
shape,
dim,
name,
const_cast<const int*>(i_output_shape_data),
output_dims.size(),
static_cast<const char*>(output_var_name.c_str()),
op_code,
coeff);
input_num,
shape,
dim,
name,
const_cast<const int*>(&i_output_shape_data[0]),
output_dims.size(),
static_cast<const char*>(output_var_name.c_str()),
op_code,
coeff);
} else {
const float* y_data = const_cast<const float*>(y->mutable_data<float>());
const float* x_data = const_cast<const float*>(x->mutable_data<float>());
const float* y_data =
const_cast<const float*>(y->mutable_data<float>());
const float* x_data =
const_cast<const float*>(x->mutable_data<float>());
bm_add_const_tensor(graph->GetCompilerHandle(),
name[1],
shape[0],
dim[0],
static_cast<bm_data_type_t>(DTYPE_FP32),
static_cast<const void*>(y_data));
name[1],
shape[0],
dim[0],
static_cast<bm_data_type_t>(DTYPE_FP32),
static_cast<const void*>(y_data));
add_binary_layer_v2(graph->GetCompilerHandle(),
name[0],
shape[0],
dim[0],
0,
static_cast<const float*>(x_data),
name[1],
shape[0],
dim[0],
0,
static_cast<const float*>(y_data),
static_cast<const char*>(output_var_name.c_str()),
0);
name[0],
shape[0],
dim[0],
0,
static_cast<const float*>(x_data),
name[1],
shape[0],
dim[0],
0,
static_cast<const float*>(y_data),
static_cast<const char*>(output_var_name.c_str()),
0);
}
delete [] shape;
delete [] name;
delete [] dim;
graph->AddNode(output_var_name);
return SUCCESS;
}
......@@ -141,4 +135,5 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(BM, elementwise_add, paddle::lite::subgraph::bm::ElementwiseConverter);
REGISTER_SUBGRAPH_BRIDGE(elementwise_add, kBM,
paddle::lite::subgraph::bm::ElementwiseConverter);
......@@ -11,11 +11,10 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <bmcompiler_if.h>
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/bm/bridges/graph.h"
#include "lite/kernels/bm/bridges/utility.h"
#include "bmcompiler_if.h"
namespace paddle {
namespace lite {
......@@ -30,32 +29,30 @@ int MulConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto unique_op_name = lite::subgraph::bm::UniqueName(op_type);
// only support y is const
// input
auto x_var_name = op_info->Input("X").front();
auto x = scope->FindVar(x_var_name)->GetMutable<lite::Tensor>();
auto x_dims = x->dims();
const long int* x_shape_data = const_cast<const long int*>(&x_dims.data()[0]);
int i_x_shape_data[x_dims.size()];
const int64_t* x_shape_data =
const_cast<const int64_t*>(&x_dims.data()[0]);
std::vector<int> i_x_shape_data(x_dims.size());
for (size_t i = 0; i < x_dims.size(); i++) {
i_x_shape_data[i] = static_cast<int>(x_shape_data[i]);
}
// add reshape layer
int i_x_reshape_shape_data[2];
for (size_t i = 0; i < 2; i++) {
i_x_reshape_shape_data[i] = static_cast<int>(x_shape_data[i]);
}
int reshape_param[] = {0, -1};
auto unique_op_reshape_name = lite::subgraph::bm::UniqueName(op_type + "_reshape");
auto unique_op_reshape_name =
lite::subgraph::bm::UniqueName(op_type + "_reshape");
add_reshape_layer(graph->GetCompilerHandle(),
const_cast<const int*>(i_x_shape_data),
const_cast<const int*>(&i_x_shape_data[0]),
x_dims.size(),
static_cast<const char*>(x_var_name.c_str()),
const_cast<const int*>(i_x_reshape_shape_data),
const_cast<const int*>(&i_x_reshape_shape_data[0]),
2,
static_cast<const char*>(unique_op_reshape_name.c_str()),
const_cast<const int*>(reshape_param));
......@@ -63,32 +60,30 @@ int MulConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto y_var_name = op_info->Input("Y").front();
auto y = scope->FindVar(y_var_name)->GetMutable<lite::Tensor>();
auto y_dims = y->dims();
// output
auto output_var_name = op_info->Output("Out").front();
auto output = scope->FindVar(output_var_name)->GetMutable<lite::Tensor>();
auto output_dims = output->dims();
const long int* output_shape_data = const_cast<const long int*>(&output_dims.data()[0]);
int i_output_shape_data[output_dims.size()];
const int64_t* output_shape_data =
const_cast<const int64_t*>(&output_dims.data()[0]);
std::vector<int32_t> i_output_shape_data(output_dims.size());
for (size_t i = 0; i < output_dims.size(); i++) {
i_output_shape_data[i] = static_cast<int>(output_shape_data[i]);
}
add_fc_layer(graph->GetCompilerHandle(),
const_cast<const int*>(i_x_reshape_shape_data),
2,
static_cast<const char*>(unique_op_reshape_name.c_str()),
const_cast<const int*>(i_output_shape_data),
output_dims.size(),
static_cast<const char*>(output_var_name.c_str()),
static_cast<const char*>(unique_op_name.c_str()),
i_x_reshape_shape_data[1],
i_output_shape_data[1],
static_cast<const float*>(y->mutable_data<float>()),
nullptr,
0,
0);
const_cast<const int*>(&i_x_reshape_shape_data[0]),
2,
static_cast<const char*>(unique_op_reshape_name.c_str()),
const_cast<const int*>(&i_output_shape_data[0]),
output_dims.size(),
static_cast<const char*>(output_var_name.c_str()),
static_cast<const char*>(unique_op_name.c_str()),
i_x_reshape_shape_data[1],
i_output_shape_data[1],
static_cast<const float*>(y->mutable_data<float>()),
nullptr,
0,
0);
graph->AddNode(output_var_name);
return SUCCESS;
}
......@@ -98,4 +93,5 @@ int MulConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(BM, mul, paddle::lite::subgraph::bm::MulConverter);
REGISTER_SUBGRAPH_BRIDGE(mul,
kBM, paddle::lite::subgraph::bm::MulConverter);
......@@ -14,11 +14,11 @@
#pragma once
USE_SUBGRAPH_BRIDGE(BM, relu);
USE_SUBGRAPH_BRIDGE(BM, conv2d);
USE_SUBGRAPH_BRIDGE(BM, elementwise_add);
USE_SUBGRAPH_BRIDGE(BM, pool2d);
USE_SUBGRAPH_BRIDGE(BM, softmax);
USE_SUBGRAPH_BRIDGE(BM, mul);
USE_SUBGRAPH_BRIDGE(BM, batch_norm);
USE_SUBGRAPH_BRIDGE(BM, scale);
USE_SUBGRAPH_BRIDGE(relu, kBM);
USE_SUBGRAPH_BRIDGE(conv2d, kBM);
USE_SUBGRAPH_BRIDGE(elementwise_add, kBM);
USE_SUBGRAPH_BRIDGE(pool2d, kBM);
USE_SUBGRAPH_BRIDGE(softmax, kBM);
USE_SUBGRAPH_BRIDGE(mul, kBM);
USE_SUBGRAPH_BRIDGE(batch_norm, kBM);
USE_SUBGRAPH_BRIDGE(scale, kBM);
......@@ -11,11 +11,10 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <bmcompiler_if.h>
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/bm/bridges/graph.h"
#include "lite/kernels/bm/bridges/utility.h"
#include "bmcompiler_if.h"
namespace paddle {
namespace lite {
......@@ -30,69 +29,65 @@ int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto unique_op_name = lite::subgraph::bm::UniqueName(op_type);
// input
auto x_var_name = op_info->Input("X").front();
auto x = scope->FindVar(x_var_name)->GetMutable<lite::Tensor>();
auto x_dims = x->dims();
const long int* x_shape_data = const_cast<const long int*>(&x_dims.data()[0]);
int i_x_shape_data[x_dims.size()];
const int64_t* x_shape_data =
const_cast<const int64_t*>(&x_dims.data()[0]);
std::vector<int32_t> i_x_shape_data(x_dims.size());
for (size_t i = 0; i < x_dims.size(); i++) {
i_x_shape_data[i] = static_cast<int>(x_shape_data[i]);
}
// output
int *shape[1];
int dim[1];
int32_t *shape[1];
int32_t dim[1];
const char *name[1];
auto output_var_name = op_info->Output("Out").front();
auto output = scope->FindVar(output_var_name)->GetMutable<lite::Tensor>();
auto output_dims = output->dims();
const long int* output_shape_data = const_cast<const long int*>(&output_dims.data()[0]);
int i_output_shape_data[output_dims.size()];
const int64_t* output_shape_data =
const_cast<const int64_t*>(&output_dims.data()[0]);
std::vector<int32_t> i_output_shape_data(output_dims.size());
for (size_t i = 0; i < output_dims.size(); i++) {
i_output_shape_data[i] = static_cast<int>(output_shape_data[i]);
}
shape[0] = i_output_shape_data;
shape[0] = &i_output_shape_data[0];
name[0] = static_cast<const char*>(output_var_name.c_str());
dim[0] = output_dims.size();
auto pooling_type = op_info->GetAttr<std::string>("pooling_type");
CHECK(pooling_type == "max" || pooling_type == "avg");
auto ksize = op_info->GetAttr<std::vector<int>>("ksize");
auto paddings = op_info->GetAttr<std::vector<int>>("paddings");
auto strides = op_info->GetAttr<std::vector<int>>("strides");
auto global_pooling = op_info->GetAttr<bool>("global_pooling");
auto ceil_mode = op_info->GetAttr<bool>("ceil_mode");
bool average_exclusive = false;
if (pooling_type == "avg") {
average_exclusive = op_info->GetAttr<bool>("exclusive");
}
add_pooling_layer(graph->GetCompilerHandle(),
const_cast<const int*>(i_x_shape_data),
x_dims.size(),
static_cast<const char*>(x_var_name.c_str()),
1,
shape,
dim,
name,
ksize[0],
ksize[1],
paddings[0],
paddings[0],
paddings[1],
paddings[1],
strides[0],
strides[1],
(ksize[0] > 1 && ksize[1] > 1) && pooling_type == "max" ? 0 : 1,
static_cast<int>(average_exclusive),
static_cast<int>(global_pooling),
static_cast<int>(ceil_mode),
static_cast<const char*>(unique_op_name.c_str()),
nullptr);
const_cast<const int*>(&i_x_shape_data[0]),
x_dims.size(),
static_cast<const char*>(x_var_name.c_str()),
1,
shape,
dim,
name,
ksize[0],
ksize[1],
paddings[0],
paddings[0],
paddings[1],
paddings[1],
strides[0],
strides[1],
(ksize[0] > 1 && ksize[1] > 1) && pooling_type == "max" ? 0 : 1,
static_cast<int>(average_exclusive),
static_cast<int>(global_pooling),
static_cast<int>(ceil_mode),
static_cast<const char*>(unique_op_name.c_str()),
nullptr);
graph->AddNode(output_var_name);
return SUCCESS;
}
......@@ -101,5 +96,5 @@ int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(BM, pool2d, paddle::lite::subgraph::bm::PoolConverter);
REGISTER_SUBGRAPH_BRIDGE(pool2d, kBM,
paddle::lite::subgraph::bm::PoolConverter);
......@@ -12,11 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <bmcompiler_op_code.h>
#include <bmcompiler_if.h>
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/bm/bridges/graph.h"
#include "lite/kernels/bm/bridges/utility.h"
#include "bmcompiler_op_code.h"
#include "bmcompiler_if.h"
namespace paddle {
namespace lite {
......@@ -32,50 +33,41 @@ int ScaleConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto unique_op_name = lite::subgraph::bm::UniqueName(op_type);
// input
auto x_var_name = op_info->Input("X").front();
auto x = scope->FindVar(x_var_name)->GetMutable<lite::Tensor>();
auto x_dims = x->dims();
const long int* x_shape_data = const_cast<const long int*>(&x_dims.data()[0]);
int i_x_shape_data[x_dims.size()];
const int64_t* x_shape_data =
const_cast<const int64_t*>(&x_dims.data()[0]);
std::vector<int32_t> i_x_shape_data(x_dims.size());
for (size_t i = 0; i < x_dims.size(); i++) {
i_x_shape_data[i] = static_cast<int>(x_shape_data[i]);
}
// output
auto output_var_name = op_info->Output("Out").front();
auto scale = op_info->GetAttr<float>("scale");
auto bias = op_info->GetAttr<float>("bias");
auto bias_after_scale = op_info->GetAttr<bool>("bias_after_scale");
if (!bias_after_scale) {
bias *= scale;
}
auto unique_op_scale_name = lite::subgraph::bm::UniqueName(op_type);
add_const_binary_layer(graph->GetCompilerHandle(),
static_cast<const char*>(x_var_name.c_str()),
const_cast<const int*>(i_x_shape_data),
x_dims.size(),
scale,
static_cast<const char*>(unique_op_scale_name.c_str()),
BINARY_MUL,
0);
auto unique_op_scale_name = lite::subgraph::bm::UniqueName(op_type);
add_const_binary_layer(graph->GetCompilerHandle(),
static_cast<const char*>(unique_op_scale_name.c_str()),
const_cast<const int*>(i_x_shape_data),
static_cast<const char*>(x_var_name.c_str()),
const_cast<const int*>(&i_x_shape_data[0]),
x_dims.size(),
bias,
static_cast<const char*>(output_var_name.c_str()),
BINARY_ADD,
scale,
static_cast<const char*>(unique_op_scale_name.c_str()),
BINARY_MUL,
0);
add_const_binary_layer(graph->GetCompilerHandle(),
static_cast<const char*>(unique_op_scale_name.c_str()),
const_cast<const int*>(&i_x_shape_data[0]),
x_dims.size(),
bias,
static_cast<const char*>(output_var_name.c_str()),
BINARY_ADD,
0);
graph->AddNode(output_var_name);
return SUCCESS;
}
......@@ -85,4 +77,5 @@ int ScaleConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(BM, scale, paddle::lite::subgraph::bm::ScaleConverter);
REGISTER_SUBGRAPH_BRIDGE(scale, kBM,
paddle::lite::subgraph::bm::ScaleConverter);
......@@ -11,11 +11,10 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <bmcompiler_if.h>
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/bm/bridges/graph.h"
#include "lite/kernels/bm/bridges/utility.h"
#include "bmcompiler_if.h"
namespace paddle {
namespace lite {
......@@ -28,46 +27,44 @@ int SoftmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto graph = static_cast<Graph*>(ctx);
auto scope = op->scope();
auto op_info = op->op_info();
// input
auto x_var_name = op_info->Input("X").front();
auto x = scope->FindVar(x_var_name)->GetMutable<lite::Tensor>();
auto x_dims = x->dims();
const long int* x_shape_data = const_cast<const long int*>(&x_dims.data()[0]);
int i_x_shape_data[x_dims.size()];
for (size_t i = 0; i < x_dims.size(); i++) {
const int64_t* x_shape_data =
const_cast<const int64_t*>(&x_dims.data()[0]);
size_t length = x_dims.size();
std::vector<int32_t> i_x_shape_data(length);
for (size_t i = 0; i < length; i++) {
i_x_shape_data[i] = static_cast<int>(x_shape_data[i]);
}
// output
auto output_var_name = op_info->Output("Out").front();
auto output = scope->FindVar(output_var_name)->GetMutable<lite::Tensor>();
auto output_dims = output->dims();
const long int* output_shape_data = const_cast<const long int*>(&output_dims.data()[0]);
int i_output_shape_data[output_dims.size()];
for (size_t i = 0; i < output_dims.size(); i++) {
const int64_t* output_shape_data =
const_cast<const int64_t*>(&output_dims.data()[0]);
length = output_dims.size();
std::vector<int32_t> i_output_shape_data(length);
for (size_t i = 0; i < length; i++) {
i_output_shape_data[i] = static_cast<int>(output_shape_data[i]);
}
auto axis = op_info->GetAttr<int>("axis");
if (axis < 0) {
axis += x_dims.size();
}
int outer_num = x_dims.Slice(0, axis).production();
int inner_num = x_dims.Slice(axis + 1, x_dims.size()).production();
add_softmax_layer(graph->GetCompilerHandle(),
const_cast<const int*>(i_x_shape_data),
const_cast<const int*>(&i_x_shape_data[0]),
x_dims.size(),
static_cast<const char*>(x_var_name.c_str()),
const_cast<const int*>(i_output_shape_data),
const_cast<const int*>(&i_output_shape_data[0]),
output_dims.size(),
static_cast<const char*>(output_var_name.c_str()),
inner_num,
outer_num,
x_dims[axis]);
graph->AddNode(output_var_name);
return SUCCESS;
}
......@@ -77,4 +74,5 @@ int SoftmaxConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(BM, softmax, paddle::lite::subgraph::bm::SoftmaxConverter);
REGISTER_SUBGRAPH_BRIDGE(softmax, kBM,
paddle::lite::subgraph::bm::SoftmaxConverter);
......@@ -17,6 +17,7 @@
#include <time.h>
#include <string>
#include <vector>
#include <utility>
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
#include "lite/kernels/bm/bridges/graph.h"
......@@ -34,18 +35,17 @@ int SubgraphEngine::BuildDeviceProgram() {
const auto& bridges = subgraph::Registry::Instance();
graph.CreateCompilerHandle();
auto& ctx = this->ctx_->template As<BMContext>();
for (auto& inst : origin_program_) {
auto op = inst.op();
CHECK(op);
op->CheckShape();
op->InferShape();
std::string op_type = op->op_info()->Type();
if (!bridges.Exists("BM", op_type)) {
if (!bridges.Exists(op_type, "kBM")) {
return subgraph::FAILED;
}
auto kernel = inst.kernel();
status |= bridges.Select("BM", op_type)(reinterpret_cast<void*>(&graph),
status |= bridges.Select(op_type, "kBM")(reinterpret_cast<void*>(&graph),
const_cast<OpLite*>(op),
const_cast<KernelBase*>(kernel));
if (subgraph::CHECK_FAILED(status)) {
......@@ -54,8 +54,8 @@ int SubgraphEngine::BuildDeviceProgram() {
}
std::string net_name = "paddle_bitmain";
__bmcompile_opt(graph.GetCompilerHandle(), const_cast<char*>(net_name.c_str()), 2);
__bmcompile_opt(graph.GetCompilerHandle(),
const_cast<char*>(net_name.c_str()), 2);
void* bmodel_data = nullptr;
unsigned int data_size = 0;
bm_hd_ = static_cast<bm_handle_t>(ctx.GetHandle());
......@@ -64,32 +64,30 @@ int SubgraphEngine::BuildDeviceProgram() {
if (false == bmrt_load_bmodel_data(bmrt_hd_, bmodel_data, data_size)) {
return subgraph::FAILED;
}
bmrt_get_network_names(bmrt_hd_, &net_names_);
net_info_ = bmrt_get_network_info(bmrt_hd_, net_names_[0]);
auto &stage = net_info_->stages[0];
// input
origin_idims_.resize(input_names_.size());
origin_itensors_.resize(input_names_.size());
device_inputs_.resize(input_names_.size());
device_inputs_.resize(input_names_.size());
for (size_t i = 0; i < input_names_.size(); i++) {
origin_itensors_[i] = scope_->FindMutableTensor(input_names_[i]);
CHECK(origin_itensors_[i]);
origin_idims_[i] = origin_itensors_[i]->dims();
bm_device_mem_t* p_mem = static_cast<bm_device_mem_t*>(malloc(sizeof(bm_device_mem_t)));
origin_idims_[i] = origin_itensors_[i]->dims();
bm_device_mem_t* p_mem =
static_cast<bm_device_mem_t*>(malloc(sizeof(bm_device_mem_t)));
CHECK(p_mem != nullptr);
CHECK(bm_malloc_device_byte(bm_hd_, p_mem, origin_itensors_[i]->memory_size()) == BM_SUCCESS);
CHECK_EQ(bm_malloc_device_byte(bm_hd_,
p_mem, origin_itensors_[i]->memory_size()), BM_SUCCESS);
bmrt_tensor_with_device(&device_inputs_[i], *p_mem,
net_info_->input_dtypes[i],
stage.input_shapes[i]);
}
// output
// output
origin_odims_.resize(output_names_.size());
origin_otensors_.resize(output_names_.size());
device_outputs_.resize(output_names_.size());
for (size_t i = 0; i < output_names_.size(); i++) {
origin_otensors_[i] = scope_->FindMutableTensor(output_names_[i]);
CHECK(origin_otensors_[i]);
......@@ -97,12 +95,13 @@ int SubgraphEngine::BuildDeviceProgram() {
output_map_.insert(std::pair<std::string, int>(output_names_[i], i));
origin_otensors_[i]->mutable_data<float>();
}
for (size_t i = 0; i < output_names_.size(); i++) {
int mapping_index = output_map_.at(net_info_->output_names[i]);
bm_device_mem_t* p_mem = static_cast<bm_device_mem_t*>(malloc(sizeof(bm_device_mem_t)));
bm_device_mem_t* p_mem =
static_cast<bm_device_mem_t*>(malloc(sizeof(bm_device_mem_t)));
CHECK(p_mem != nullptr);
CHECK(bm_malloc_device_byte(bm_hd_, p_mem, origin_otensors_[mapping_index]->memory_size()) == BM_SUCCESS);
CHECK_EQ(bm_malloc_device_byte(bm_hd_,
p_mem, origin_otensors_[mapping_index]->memory_size()), BM_SUCCESS);
bmrt_tensor_with_device(&device_outputs_[i], *p_mem,
net_info_->output_dtypes[i],
stage.output_shapes[i]);
......@@ -113,14 +112,21 @@ int SubgraphEngine::BuildDeviceProgram() {
int SubgraphEngine::LaunchDeviceProgram() {
for (size_t i = 0; i < device_inputs_.size(); i++) {
bm_memcpy_s2d(bm_hd_, device_inputs_[i].device_mem, const_cast<void*>(origin_itensors_[i]->raw_data()));
bm_memcpy_s2d(bm_hd_,
device_inputs_[i].device_mem,
const_cast<void*>(origin_itensors_[i]->raw_data()));
}
bmrt_launch_tensor_ex(bmrt_hd_, net_names_[0], static_cast<const bm_tensor_t*>(&device_inputs_[0]),
net_info_->input_num, static_cast<bm_tensor_t*>(&device_outputs_[0]), net_info_->output_num, true, false);
bm_thread_sync(bm_hd_);
bmrt_launch_tensor_ex(bmrt_hd_,
net_names_[0],
static_cast<const bm_tensor_t*>(&device_inputs_[0]),
net_info_->input_num,
static_cast<bm_tensor_t*>(&device_outputs_[0]),
net_info_->output_num, true, false);
bm_thread_sync(bm_hd_);
for (size_t i = 0; i < device_outputs_.size(); i++) {
bm_memcpy_d2s(bm_hd_, const_cast<void*>(origin_otensors_[i]->raw_data()), device_outputs_[i].device_mem);
bm_memcpy_d2s(bm_hd_,
const_cast<void*>(origin_otensors_[i]->raw_data()),
device_outputs_[i].device_mem);
}
return 0;
}
......
if(NOT LITE_WITH_NPU AND NOT LITE_WITH_XPU)
if(NOT LITE_WITH_NPU AND NOT LITE_WITH_XPU AND NOT LITE_WITH_BM)
return()
endif()
......
if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH_ARM))
if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_BM) AND (LITE_WITH_X86 OR LITE_WITH_ARM))
lite_cc_test(test_kernel_scale_compute SRCS scale_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_power_compute SRCS power_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_shuffle_channel_compute SRCS shuffle_channel_compute_test.cc DEPS arena_framework ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......@@ -35,36 +35,36 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH
lite_cc_test(test_kernel_pool_compute SRCS pool_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
if(LITE_BUILD_EXTRA)
lite_cc_test(test_gru_unit SRCS gru_unit_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_sequence_pool_compute SRCS sequence_pool_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reduce_max_compute SRCS reduce_max_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_unsqueeze_compute SRCS unsqueeze_compute_test.cc DEPS arena_framework ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_assign_compute SRCS assign_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_assign_value_compute SRCS assign_value_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_box_clip_compute SRCS box_clip_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reduce_mean_compute SRCS reduce_mean_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reduce_prod_compute SRCS reduce_prod_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_stack_compute SRCS stack_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_range_compute SRCS range_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_affine_channel_compute SRCS affine_channel_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_anchor_generator_compute SRCS anchor_generator_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_generate_proposals_compute SRCS generate_proposals_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_roi_align_compute SRCS roi_align_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_search_aligned_mat_mul_compute SRCS search_aligned_mat_mul_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_search_seq_fc_compute SRCS search_seq_fc_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_lookup_table_compute SRCS lookup_table_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_gather_compute SRCS gather_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_gru_unit SRCS gru_unit_test.cc DEPS arena_framework ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_sequence_pool_compute SRCS sequence_pool_compute_test.cc DEPS ${bm_kernels} arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reduce_max_compute SRCS reduce_max_compute_test.cc DEPS arena_framework ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_unsqueeze_compute SRCS unsqueeze_compute_test.cc DEPS arena_framework ${bm_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_assign_compute SRCS assign_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_assign_value_compute SRCS assign_value_compute_test.cc DEPS arena_framework ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_box_clip_compute SRCS box_clip_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reduce_mean_compute SRCS reduce_mean_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reduce_prod_compute SRCS reduce_prod_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_stack_compute SRCS stack_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_range_compute SRCS range_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_affine_channel_compute SRCS affine_channel_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_anchor_generator_compute SRCS anchor_generator_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_generate_proposals_compute SRCS generate_proposals_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_roi_align_compute SRCS roi_align_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_search_aligned_mat_mul_compute SRCS search_aligned_mat_mul_compute_test.cc DEPS arena_framework ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_search_seq_fc_compute SRCS search_seq_fc_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${bm_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_lookup_table_compute SRCS lookup_table_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_gather_compute SRCS gather_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
endif()
lite_cc_test(test_kernel_pad2d_compute SRCS pad2d_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_prior_box_compute SRCS prior_box_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_negative_compute SRCS negative_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_bilinear_interp_compute SRCS bilinear_interp_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_nearest_interp_compute SRCS nearest_interp_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_shape_compute SRCS shape_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_crop_compute SRCS crop_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_sequence_expand_compute SRCS sequence_expand_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_squeeze_compute SRCS squeeze_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_slice_compute SRCS slice_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_expand_compute SRCS expand_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_matmul_compute SRCS matmul_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_pad2d_compute SRCS pad2d_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_prior_box_compute SRCS prior_box_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_negative_compute SRCS negative_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_bilinear_interp_compute SRCS bilinear_interp_compute_test.cc DEPS arena_framework ${xpu_kernels} ${bm_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_nearest_interp_compute SRCS nearest_interp_compute_test.cc DEPS arena_framework ${xpu_kernels} ${bm_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_shape_compute SRCS shape_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_crop_compute SRCS crop_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${bm_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_sequence_expand_compute SRCS sequence_expand_compute_test.cc DEPS arena_framework ${xpu_kernels} ${bm_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_squeeze_compute SRCS squeeze_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_slice_compute SRCS slice_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_expand_compute SRCS expand_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_matmul_compute SRCS matmul_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${bm_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
endif()
#!/bin/bash
set -ex
# global variables with default value
BM_SDK_ROOT="$(pwd)/../BM_SDK" # BM SDK
TARGET_NAME="BM1682" # default target
BUILD_EXTRA=OFF # ON(with sequence ops)/OFF
WITH_TESTING=ON # ON/OFF
function print_usage {
echo -e "\nUSAGE:"
echo
echo "----------------------------------------"
echo -e "--bm_sdk_root=<bm sdk directory>"
echo -e "--target_name=<target name>"
echo "----------------------------------------"
echo
}
# readonly variables with default value
readonly CMAKE_COMMON_OPTIONS="-DWITH_LITE=ON \
-DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=OFF \
-DWITH_PYTHON=OFF \
-DLITE_WITH_ARM=OFF"
readonly NUM_CORES_FOR_COMPILE=${LITE_BUILD_THREADS:-1}
readonly THIRDPARTY_TAR=https://paddle-inference-dist.bj.bcebos.com/PaddleLite/third-party-05b862.tar.gz
readonly workspace=$(pwd)
function prepare_thirdparty {
if [ ! -d $workspace/third-party -o -f $workspace/third-party-05b862.tar.gz ]; then
rm -rf $workspace/third-party
if [ ! -f $workspace/third-party-05b862.tar.gz ]; then
wget $THIRDPARTY_TAR
fi
tar xzf third-party-05b862.tar.gz
else
git submodule update --init --recursive
fi
}
# For code gen, a source file is generated after a test, but is depended on by some targets in cmake.
# Here we fake an empty file to make cmake work.
function prepare_workspace {
# in build directory
# 1. Prepare gen_code file
GEN_CODE_PATH_PREFIX=lite/gen_code
mkdir -p ./${GEN_CODE_PATH_PREFIX}
touch ./${GEN_CODE_PATH_PREFIX}/__generated_code__.cc
# 2.Prepare debug tool
DEBUG_TOOL_PATH_PREFIX=lite/tools/debug
mkdir -p ./${DEBUG_TOOL_PATH_PREFIX}
cp ../${DEBUG_TOOL_PATH_PREFIX}/analysis_tool.py ./${DEBUG_TOOL_PATH_PREFIX}/
# clone submodule
# git submodule update --init --recursive
prepare_thirdparty
}
function build_bm {
build_dir=${workspace}/build.lite.bm
mkdir -p $build_dir
cd $build_dir
prepare_workspace
cmake .. \
${CMAKE_COMMON_OPTIONS} \
-DWITH_GPU=OFF \
-DWITH_MKLDNN=OFF \
-DLITE_WITH_X86=ON \
-DWITH_MKL=ON \
-DLITE_BUILD_EXTRA=ON \
-DLITE_WITH_XPU=OFF \
-DLITE_WITH_BM=ON \
-DWITH_TESTING=${WITH_TESTING} \
-DBM_SDK_ROOT=${BM_SDK_ROOT}
make -j$NUM_CORES_FOR_COMPILE
cd -
echo "Done"
}
function main {
# Parse command line.
for i in "$@"; do
case $i in
--target_name=*)
TARGET_NAME="${i#*=}"
shift
;;
--bm_sdk_root=*)
BM_SDK_ROOT="${i#*=}"
shift
;;
bm)
build_bm
shift
;;
*)
# unknown option
print_usage
exit 1
;;
esac
done
}
main $@
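A minimal usage sketch for the build script above. The script's file path is not shown in this diff, so the name used below (lite/tools/build_bm.sh) and the SDK location are assumptions for illustration only; the flags themselves come from the script's argument parser.

# Hypothetical invocation: build Paddle-Lite with the Bitmain (BM) backend enabled.
# The script path and SDK directory are assumed, not taken from this diff.
./lite/tools/build_bm.sh --bm_sdk_root=/path/to/BM_SDK --target_name=BM1682 bm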