提交 8d9cc823 编写于 作者: J jiaopu 提交者: jackzhang235

Add cast op

上级 aa70ff8e
......@@ -22,6 +22,7 @@ lite_cc_library(subgraph_bridge_transpose_op_mlu SRCS transpose_op.cc DEPS ${sub
lite_cc_library(subgraph_bridge_dropout_op_mlu SRCS dropout_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_slice_op_mlu SRCS slice_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_split_op_mlu SRCS split_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_cast_op_mlu SRCS cast_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_argmax_op_mlu SRCS argmax_op.cc DEPS ${subgraph_bridge_deps_mlu})
lite_cc_library(subgraph_bridge_squeeze_op_mlu SRCS squeeze_op.cc DEPS ${subgraph_bridge_deps_mlu})
set(mlu_subgraph_bridges
......@@ -42,6 +43,7 @@ set(mlu_subgraph_bridges
subgraph_bridge_dropout_op_mlu
subgraph_bridge_slice_op_mlu
subgraph_bridge_split_op_mlu
subgraph_bridge_cast_op_mlu
subgraph_bridge_argmax_op_mlu
subgraph_bridge_squeeze_op_mlu
CACHE INTERNAL "mlu_subgraph_bridges")
......@@ -69,6 +71,7 @@ lite_cc_test(test_transpose_converter_mlu SRCS transpose_op_test.cc DEPS scope o
lite_cc_test(test_dropout_converter_mlu SRCS dropout_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_slice_converter_mlu SRCS slice_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_split_converter_mlu SRCS split_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_cast_converter_mlu SRCS cast_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_argmax_converter_mlu SRCS argmax_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
lite_cc_test(test_squeeze_converter_mlu SRCS squeeze_op_test.cc DEPS scope optimizer target_wrapper_host model_parser program ${mlu_subgraph_bridges} subgraph_compute_mlu subgraph_test_helper_mlu)
if (LITE_BUILD_EXTRA)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/mlu/bridges/graph.h"
#include "lite/kernels/mlu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace mlu {
// Converts a Paddle "cast" op into a CNML cast operation and fuses it into
// the MLU subgraph being built in `ctx`.
//
// ctx:    the subgraph::mlu::Graph under construction (must be non-null).
// op:     the lite "cast" op carrying the "in_dtype"/"out_dtype" attributes.
// kernel: unused here; present to match the converter-function signature.
//
// Only FP16 <-> FP32 conversions are supported; any other dtype pair
// CHECK-fails. Returns SUCCESS on success.
int CastConverter(void* ctx, OpLite* op, KernelBase* kernel) {
  CHECK(ctx != nullptr);
  CHECK(op != nullptr);
  auto graph = static_cast<Graph*>(ctx);
  auto op_info = op->op_info();
  auto op_type = op_info->Type();
  auto scope = op->scope();
  VLOG(3) << "[MLU] Converting " + op_type + "...";

  auto x_var_name = op_info->Input("X").front();
  auto out_var_name = op_info->Output("Out").front();
  auto output = scope->FindVar(out_var_name)->GetMutable<Tensor>();
  auto output_dims = output->dims().Vectorize();

  // Dtype codes carried by the cast op attributes. These follow Paddle's
  // framework::proto::VarType::Type enum: 4 == FP16, 5 == FP32.
  constexpr int kFP16 = 4;
  constexpr int kFP32 = 5;
  auto in_dtype = op_info->GetAttr<int>("in_dtype");
  auto out_dtype = op_info->GetAttr<int>("out_dtype");

  CHECK(graph->HasNode(x_var_name));
  auto x_tensor = graph->GetNode(x_var_name);

  // The MLU output tensor dtype follows out_dtype. Initialize with a valid
  // value so the variable is never read uninitialized even in builds where
  // CHECK(0) does not abort.
  cnmlDataType_t data_type = CNML_DATA_FLOAT32;
  if (out_dtype == kFP16) {
    data_type = CNML_DATA_FLOAT16;
  } else if (out_dtype == kFP32) {
    data_type = CNML_DATA_FLOAT32;
  } else {
    CHECK(0) << "Unsupported cast output dtype: " << out_dtype;
  }
  auto output_tensor = graph->AddNode(
      out_var_name, output_dims, CNML_TENSOR, CNML_NCHW, data_type);

  // Map the (in_dtype, out_dtype) pair onto a CNML cast direction.
  cnmlCastType_t cast_type = CNML_CAST_FLOAT16_TO_FLOAT32;
  if (in_dtype == kFP16 && out_dtype == kFP32) {
    cast_type = CNML_CAST_FLOAT16_TO_FLOAT32;
  } else if (in_dtype == kFP32 && out_dtype == kFP16) {
    cast_type = CNML_CAST_FLOAT32_TO_FLOAT16;
  } else {
    CHECK(0) << "Unsupported cast type: in_dtype=" << in_dtype
             << " out_dtype=" << out_dtype;
  }

  cnmlBaseOp_t cast_op;
  CNML_CALL(cnmlCreateCastOp(&cast_op,
                             cast_type,
                             x_tensor->mlu_tensor(),
                             output_tensor->mlu_tensor()));
  // Hand the op over to the graph, then release the local handle
  // immediately, as the original code does.
  graph->FuseOp(cast_op);
  CNML_CALL(cnmlDestroyBaseOp(&cast_op));
  return SUCCESS;
}
} // namespace mlu
} // namespace subgraph
} // namespace lite
} // namespace paddle
// Register the converter so the MLU subgraph pass can bridge "cast" ops.
REGISTER_SUBGRAPH_BRIDGE(cast,
kMLU,
paddle::lite::subgraph::mlu::CastConverter);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/cast_op.h"
#include <gtest/gtest.h>
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/mlu/bridges/test_helper.h"
#include "lite/kernels/mlu/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace mlu {
// Builds a standalone "cast" op casting FP16 -> FP32, runs it through the
// MLU bridge via LaunchOp, and checks that the FP32 output matches the
// original FP16 input values.
void test_cast_FP16_to_FP32(std::vector<int64_t> shape) {
  // Prepare the input/output variables in a fresh scope.
  const std::string x_var_name("x");
  const std::string out_var_name("out");
  Scope scope;
  auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
  auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
  x->Resize(DDim(shape));

  // Fill the input with 0, 1, 2, ... in half precision.
  auto* x_data = x->mutable_data<paddle::lite::fluid::float16>();
  const auto count = x->dims().production();
  for (int64_t idx = 0; idx < count; ++idx) {
    x_data[idx] = static_cast<paddle::lite::fluid::float16>(idx);
  }

  // Describe the cast op: in_dtype 4 (FP16) -> out_dtype 5 (FP32).
  const int in_dtype = 4;
  const int out_dtype = 5;
  cpp::OpDesc opdesc;
  opdesc.SetType("cast");
  opdesc.SetInput("X", {x_var_name});
  opdesc.SetOutput("Out", {out_var_name});
  opdesc.SetAttr("in_dtype", in_dtype);
  opdesc.SetAttr("out_dtype", out_dtype);
  auto op = CreateOp<operators::CastOp>(opdesc, &scope);

  // Keep a host-side reference copy of the input, since LaunchOp replaces
  // x's buffer with device data. NOTE(review): the pointer is fetched
  // before CopyDataFrom, which assumes the copy does not reallocate the
  // buffer — same assumption as the original test.
  Tensor reference;
  reference.Resize(DDim(shape));
  auto* ref_data = reference.mutable_data<paddle::lite::fluid::float16>();
  reference.CopyDataFrom(*x);

  x->set_precision(paddle::lite_api::PrecisionType::kFP16);
  LaunchOp(op, {x_var_name}, {out_var_name});

  // Compare: the cast is value-preserving for these small integers.
  auto* out_data = out->mutable_data<float>();
  for (int64_t idx = 0; idx < out->dims().production(); ++idx) {
    VLOG(5) << idx;
    EXPECT_NEAR(out_data[idx], static_cast<double>(ref_data[idx]), 5e-4);
  }
}
// Builds a standalone "cast" op casting FP32 -> FP16, runs it through the
// MLU bridge via LaunchOp, and checks that the FP16 output matches the
// original FP32 input values.
void test_cast_FP32_to_FP16(std::vector<int64_t> shape) {
  // Prepare the input/output variables in a fresh scope.
  const std::string x_var_name("x");
  const std::string out_var_name("out");
  Scope scope;
  auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
  auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
  x->Resize(DDim(shape));

  // Fill the input with 0, 1, 2, ... in single precision.
  auto* x_data = x->mutable_data<float>();
  const auto count = x->dims().production();
  for (int64_t idx = 0; idx < count; ++idx) {
    x_data[idx] = static_cast<float>(idx);
  }

  // Describe the cast op: in_dtype 5 (FP32) -> out_dtype 4 (FP16).
  const int in_dtype = 5;
  const int out_dtype = 4;
  cpp::OpDesc opdesc;
  opdesc.SetType("cast");
  opdesc.SetInput("X", {x_var_name});
  opdesc.SetOutput("Out", {out_var_name});
  opdesc.SetAttr("in_dtype", in_dtype);
  opdesc.SetAttr("out_dtype", out_dtype);
  auto op = CreateOp<operators::CastOp>(opdesc, &scope);

  // Keep a host-side reference copy of the input, since LaunchOp replaces
  // x's buffer with device data. NOTE(review): the pointer is fetched
  // before CopyDataFrom, which assumes the copy does not reallocate the
  // buffer — same assumption as the original test.
  Tensor reference;
  reference.Resize(DDim(shape));
  auto* ref_data = reference.mutable_data<float>();
  reference.CopyDataFrom(*x);

  x->set_precision(paddle::lite_api::PrecisionType::kFloat);
  LaunchOp(op, {x_var_name}, {out_var_name});

  // Compare: the cast is value-preserving for these small integers.
  auto* out_data = out->mutable_data<paddle::lite::fluid::float16>();
  for (int64_t idx = 0; idx < out->dims().production(); ++idx) {
    VLOG(5) << idx;
    EXPECT_NEAR(static_cast<double>(out_data[idx]), ref_data[idx], 5e-4);
  }
}
// Exercise both cast directions over a couple of 4-D shapes, keeping the
// original execution order: all FP16->FP32 runs first, then FP32->FP16.
TEST(MLUBridges, cast) {
  const std::vector<std::vector<int64_t>> shapes{{2, 3, 4, 5}, {6, 3, 2, 5}};
  for (const auto& shape : shapes) {
    test_cast_FP16_to_FP32(shape);
  }
  for (const auto& shape : shapes) {
    test_cast_FP32_to_FP16(shape);
  }
}
} // namespace mlu
} // namespace subgraph
} // namespace lite
} // namespace paddle
USE_SUBGRAPH_BRIDGE(cast, kMLU);
......@@ -34,6 +34,7 @@ USE_SUBGRAPH_BRIDGE(elementwise_mul, kMLU);
USE_SUBGRAPH_BRIDGE(dropout, kMLU);
USE_SUBGRAPH_BRIDGE(argmax, kMLU);
USE_SUBGRAPH_BRIDGE(split, kMLU);
USE_SUBGRAPH_BRIDGE(cast, kMLU);
USE_SUBGRAPH_BRIDGE(slice, kMLU);
USE_SUBGRAPH_BRIDGE(squeeze, kMLU);
USE_SUBGRAPH_BRIDGE(squeeze2, kMLU);
......
......@@ -83,6 +83,23 @@ void LaunchOp(const std::shared_ptr<lite::OpLite> op,
temp_input.mutable_data<int>(),
sizeof(int) * input_tensor->dims().production(),
CNRT_MEM_TRANS_DIR_HOST2DEV));
} else if (fp_type == CNML_DATA_FLOAT16) {
auto input_node = graph.AddNode(
input_name,
input_tensor->dims().Vectorize(),
CNML_TENSOR,
CNML_NCHW,
fp_type,
reinterpret_cast<void*>(
input_tensor->mutable_data<paddle::lite::fluid::float16>(
TARGET(kMLU))));
CHECK(input_node);
CNRT_CHECK(
cnrtMemcpy(input_tensor->mutable_data<paddle::lite::fluid::float16>(),
temp_input.mutable_data<paddle::lite::fluid::float16>(),
sizeof(paddle::lite::fluid::float16) *
input_tensor->dims().production(),
CNRT_MEM_TRANS_DIR_HOST2DEV));
} else {
auto input_node =
graph.AddNode(input_name,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册