未验证 提交 111c0b6d 编写于 作者: Y Yanzhan Yang 提交者: GitHub

fix elementwise_add && fix run.py && add namespace change script (#1898)

上级 788238b0
...@@ -25,17 +25,27 @@ bool ElementwiseAddKernel<GPU_CL, float>::Init( ...@@ -25,17 +25,27 @@ bool ElementwiseAddKernel<GPU_CL, float>::Init(
DLOG << "-----init add-----"; DLOG << "-----init add-----";
CLImage *bias = CLImage *bias =
reinterpret_cast<CLImage *>(const_cast<CLImage *>(param->InputY())); reinterpret_cast<CLImage *>(const_cast<CLImage *>(param->InputY()));
if (!bias->isInit()) {
bias->InitNormalCLImage(cl_helper_.CLContext(),
this->cl_helper_.CLCommandQueue());
}
DLOG << " bias: " << *bias;
if (bias->dims().size() == 4) { if (bias->dims().size() == 4) {
if (!bias->isInit()) {
bias->InitNormalCLImage(cl_helper_.CLContext(),
this->cl_helper_.CLCommandQueue());
}
DLOG << " bias: " << *bias;
this->cl_helper_.AddKernel("elementwise_add", "elementwise_add_kernel.cl"); this->cl_helper_.AddKernel("elementwise_add", "elementwise_add_kernel.cl");
} else if (param->InputY()->dims().size() == 1) { } else if (param->InputY()->dims().size() == 1) {
if (param->Axis() == param->InputX()->dims().size() - 1) { if (param->Axis() == param->InputX()->dims().size() - 1) {
if (!bias->isInit()) {
bias->InitNormalCLImage(cl_helper_.CLContext(),
this->cl_helper_.CLCommandQueue());
}
DLOG << " bias: " << *bias;
this->cl_helper_.AddKernel("width_add", "channel_add_kernel.cl"); this->cl_helper_.AddKernel("width_add", "channel_add_kernel.cl");
} else if (param->Axis() == param->InputX()->dims().size() - 3) { } else if (param->Axis() == param->InputX()->dims().size() - 3) {
if (!bias->isInit()) {
bias->InitCLImage(cl_helper_.CLContext(),
this->cl_helper_.CLCommandQueue());
}
DLOG << " bias: " << *bias;
this->cl_helper_.AddKernel("channel_add", "channel_add_kernel.cl"); this->cl_helper_.AddKernel("channel_add", "channel_add_kernel.cl");
} else { } else {
DLOG << "error:bias dims is error"; DLOG << "error:bias dims is error";
......
...@@ -285,8 +285,9 @@ def save_all_op_output(feed_kv=None): ...@@ -285,8 +285,9 @@ def save_all_op_output(feed_kv=None):
for fetch in fetches: for fetch in fetches:
fetch_names.append(fetch.name) fetch_names.append(fetch.name)
feed_names = feeds feed_names = feeds
for fetch_name in fetch_names: if len(output_var_filter) > 0:
output_var_filter.append(fetch_name) for fetch_name in fetch_names:
output_var_filter.append(fetch_name)
for i in range(len(ops)): for i in range(len(ops)):
op = ops[i] op = ops[i]
var_name = None var_name = None
......
#!/usr/bin/env bash
# set -o xtrace
extension=$1
convert () {
perl -pi -e "s/namespace paddle_mobile/namespace paddle_mobile_${1}/g" "${2}"
perl -pi -e "s/paddle_mobile::/paddle_mobile_${1}::/g" "${2}"
}
revert () {
perl -pi -e "s/namespace paddle_mobile_[\w]*/namespace paddle_mobile/g" "${2}"
perl -pi -e "s/paddle_mobile_[\w]*::/paddle_mobile::/g" "${2}"
}
if [[ $2 == "revert" ]]; then
for file in $(find src -name "*\.*")
do
echo "reverting ${file}"
revert $extension $file
done
for file in $(find test -name "*\.*")
do
echo "reverting ${file}"
revert $extension $file
done
else
for file in $(find src -name "*\.*")
do
echo "converting ${file}"
convert $extension $file
done
# for file in $(find test -name "*\.*")
# do
# echo "converting ${file}"
# convert $extension $file
# done
fi
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册