Commit e88d6418 authored by S superjomn

Add compatibility for server-side x86 and GPU tensors.

Parent be706306
......@@ -33,7 +33,7 @@ class OpDesc {
OpDesc(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs);
OpDesc(const proto::OpDesc &desc, BlockDesc *block);
OpDesc(const proto::OpDesc &desc, BlockDesc *block = nullptr);
explicit OpDesc(BlockDesc *block) : block_(block) {}
......
if (NOT WITH_LITE)
return()
endif()
add_subdirectory(core)
add_subdirectory(x86)
add_subdirectory(host)
if(LITE_WITH_CUDA)
add_subdirectory(cuda)
endif()
add_subdirectory(cuda)
add_subdirectory(operators)
add_subdirectory(kernels)
add_subdirectory(model_parser)
......
......@@ -41,7 +41,8 @@ TEST(CXXApi, test) {
auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize({100, 100});
auto* data = input_tensor->mutable_data<float>();
auto* data = TensorMutableData<float>(input_tensor, TARGET(kHost),
product(input_tensor->dims()));
for (int i = 0; i < 100 * 100; i++) {
data[i] = i;
}
......
cc_library(lite_gtest_main SRCS lite_gtest_main.cc)
cc_library(memory_lite SRCS memory.cc)
cc_library(target_wrapper_lite SRCS target_wrapper.cc)
cc_library(tensor_lite SRCS tensor.cc DEPS memory_lite target_wrapper_lite)
if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
cc_library(tensor_lite SRCS lite_tensor.cc DEPS memory_lite target_wrapper_lite)
else()
cc_library(tensor_lite DEPS lod_tensor)
endif()
cc_library(kernel_lite SRCS kernel.cc DEPS type_system target_wrapper_lite)
cc_library(variable_lite SRCS variable.cc)
cc_library(op_registry_lite SRCS op_registry.cc)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/compatible_tensor.h"
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/lite/core/target_wrapper.h"
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
#include "paddle/fluid/lite/core/lite_tensor.h"
#else
#include "paddle/fluid/framework/lod_tensor.h"
#endif
namespace paddle {
namespace lite {
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
using Tensor = details::Tensor;
using DDim = details::DDim;
#else
using Tensor = framework::LoDTensor;
using DDim = framework::DDim;
static TargetType TensorGetTarget(const Tensor &x) {
if (platform::is_gpu_place(x.place())) {
return TARGET(kCUDA);
} else if (platform::is_cpu_place(x.place())) {
return TARGET(kX86);
}
return TARGET(kUnk);
}
template <typename T>
T *TensorMutableData(Tensor *x, TargetType target, size_t size) {
if (target == TARGET(kX86) || target == TARGET(kHost)) {
return x->mutable_data<T>(platform::CPUPlace(), memory::Allocator::kDefault,
size);
} else if (target == TARGET(kCUDA)) {
return x->mutable_data<T>(platform::CUDAPlace(),
memory::Allocator::kDefault, size);
}
LOG(FATAL) << "not valid target " << TargetToStr(target);
return nullptr;
}
#endif
static int product(const DDim &dims, int start, int end) {
int res = 1;
for (int i = start; i < end; i++) {
res *= dims[i];
}
return res;
}
static DDim SliceDims(const DDim &dims, int begin, int end) {
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
return DDim(dims.begin() + begin, dims.begin() + end);
#else
auto vec = framework::vectorize(dims);
return DDim(&vec[0] + begin, end - begin);
#endif
}
static std::vector<int64_t> DDimVectorize(const DDim &x) {
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
return x;
#else
return framework::vectorize(x);
#endif
}
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
static int product(const DDim &dims) {
return std::accumulate(dims.begin(), dims.end(), 1,
[](int a, int b) { return a * b; });
}
#endif
static DDim flatten_to_2d(const DDim &dims, int col) {
return DDim({product(SliceDims(dims, 0, col)),
product(SliceDims(dims, col, dims.size()))});
}
} // namespace lite
} // namespace paddle
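The point of this header is that kernel and parser code can be written once against the compatible Tensor, DDim, TensorMutableData, and product names and build in either configuration. A minimal sketch of such a call site follows (only the server-side TensorMutableData overload appears in this diff, so the lightweight build is assumed to supply an equivalent; FillWithOnes is a hypothetical helper, not part of the commit):
#include "paddle/fluid/lite/core/compatible_tensor.h"
namespace paddle {
namespace lite {
// Hypothetical helper: writes ones into a host tensor. It touches only the
// compatibility API above, so the same source builds whether Tensor aliases
// details::Tensor or framework::LoDTensor.
void FillWithOnes(Tensor *t) {
float *data = TensorMutableData<float>(t, TARGET(kHost), product(t->dims()));
for (int i = 0; i < product(t->dims()); i++) data[i] = 1.f;
}
} // namespace lite
} // namespace paddle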
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/tensor.h"
#include "paddle/fluid/lite/core/lite_tensor.h"
namespace paddle {
namespace lite {
......
......@@ -17,30 +17,15 @@
#include <memory>
#include <numeric>
#include <vector>
#include "paddle/fluid/lite/core/memory.h"
#include "paddle/fluid/lite/core/target_wrapper.h"
namespace paddle {
namespace lite {
namespace details {
using DDim = std::vector<int64_t>;
static DDim SliceDims(const DDim& dims, int begin, int end) {
return DDim(dims.begin() + begin, dims.begin() + end - 1);
}
static int product(const DDim& dims) {
return std::accumulate(dims.begin(), dims.end(), 1,
[](int a, int b) { return a * b; });
}
static int product(DDim::const_iterator begin, DDim::const_iterator end) {
return std::accumulate(begin, end, 1, [](int a, int b) { return a * b; });
}
static DDim flatten_to_2d(const DDim& dims, int col) {
return DDim({product(SliceDims(dims, 0, col)),
product(SliceDims(dims, col, dims.size()))});
}
using LoD = std::vector<std::vector<size_t>>;
......@@ -50,32 +35,32 @@ class Tensor {
Tensor() : buffer_(std::make_shared<Buffer>()) {}
template <typename T>
const T* data() const {
return static_cast<const T*>(buffer_->data());
const T *data() const {
return static_cast<const T *>(buffer_->data());
}
void Resize(const DDim& ddim) { dims_ = ddim; }
void Resize(const DDim &ddim) { dims_ = ddim; }
const DDim& dims() const { return dims_; }
const DDim &dims() const { return dims_; }
const LoD& lod() const { return lod_; }
LoD* mutable_lod() { return &lod_; }
const LoD &lod() const { return lod_; }
LoD *mutable_lod() { return &lod_; }
template <typename T>
T* mutable_data();
T *mutable_data();
template <typename T>
T* mutable_data(TargetType target);
void* mutable_data(size_t memory_size);
void* mutable_data(TargetType target, size_t memory_size);
T *mutable_data(TargetType target);
void *mutable_data(size_t memory_size);
void *mutable_data(TargetType target, size_t memory_size);
size_t memory_size() const { return memory_size_; }
bool IsInitialized() const { return buffer_->data(); }
// Share the other tensor's data with this one.
void ShareDataWith(const Tensor& other);
void ShareDataWith(const Tensor &other);
void CopyDataFrom(const Tensor& other);
void CopyDataFrom(const Tensor &other);
TargetType target() const { return target_; }
......@@ -88,22 +73,23 @@ class Tensor {
};
template <typename T>
T* Tensor::mutable_data() {
T *Tensor::mutable_data() {
memory_size_ = product(dims_) * sizeof(T);
buffer_->ResetLazy(target_, memory_size_);
return static_cast<T*>(buffer_->data());
return static_cast<T *>(buffer_->data());
}
template <typename T>
T* Tensor::mutable_data(TargetType target) {
T *Tensor::mutable_data(TargetType target) {
target_ = target;
memory_size_ = product(dims_) * sizeof(T);
buffer_->ResetLazy(target, memory_size());
return static_cast<T*>(buffer_->data());
return static_cast<T *>(buffer_->data());
}
std::ostream& operator<<(std::ostream& os, const DDim& dims);
std::ostream& operator<<(std::ostream& os, const Tensor& tensor);
std::ostream &operator<<(std::ostream &os, const DDim &dims);
std::ostream &operator<<(std::ostream &os, const Tensor &tensor);
} // namespace details
} // namespace lite
} // namespace paddle
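Note how mutable_data above sizes the buffer lazily from the current dims_, so Resize must come first; ResetLazy is expected to keep the existing allocation when the requested size does not grow. A usage sketch under the lightweight build (assuming the default target is host):
#include "paddle/fluid/lite/core/lite_tensor.h"
void LiteTensorExample() {
paddle::lite::details::Tensor t;
t.Resize({2, 3});                    // sets dims_ only; no allocation yet
float *d = t.mutable_data<float>();  // allocates 2 * 3 * sizeof(float)
d[0] = 1.f;
t.mutable_data<float>();             // same size: ResetLazy reuses the buffer
}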
......@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/tensor.h"
#include <gtest/gtest.h>
#include "paddle/fluid/lite/core/lite_tensor.h"
namespace paddle {
namespace lite {
......
......@@ -27,7 +27,7 @@
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/lite/core/tensor.h"
#include "paddle/fluid/lite/core/compatible_tensor.h"
#include "paddle/fluid/lite/utils/all.h"
namespace paddle {
......
......@@ -13,7 +13,7 @@
// limitations under the License.
#pragma once
#include "paddle/fluid/lite/core/tensor.h"
#include "paddle/fluid/lite/core/compatible_tensor.h"
#include "paddle/fluid/lite/utils/all.h"
namespace paddle {
......
if(NOT LITE_WITH_CUDA)
return()
endif()
nv_library(target_wrapper_cuda SRCS target_wrapper.cc)
nv_library(cuda_blas_lite SRCS blas.cc)
set(lite_kernel_deps type_system kernel_lite op_registry_lite)
add_subdirectory(host)
add_subdirectory(arm)
if(LITE_WITH_CUDA)
add_subdirectory(cuda)
endif()
add_subdirectory(cuda)
if(NOT LITE_WITH_CUDA)
return()
endif()
nv_library(mul_compute_cuda SRCS mul_compute.cc DEPS tensor_lite)
cc_library(io_copy_compute_cuda SRCS io_copy_compute.cc DEPS tensor_lite)
......
......@@ -46,11 +46,12 @@ class IoCopyHostToCudaCompute
public:
void Run() override {
auto& param = Param<operators::IoCopyParam>();
CHECK(param.x->target() == TARGET(kHost) ||
param.x->target() == TARGET(kX86));
CHECK(TensorGetTarget(*param.x) == TARGET(kHost) ||
TensorGetTarget(*param.x) == TARGET(kX86));
LOG(INFO) << "copy size " << param.x->memory_size();
auto* data = param.y->mutable_data(TARGET(kCUDA), param.x->memory_size());
CopyFromHostSync(data, param.x->data<void>(), param.x->memory_size());
auto* data = TensorMutableData<int8_t>(param.y, TARGET(kCUDA),
param.x->memory_size());
CopyFromHostSync(data, param.x->data<int8_t>(), param.x->memory_size());
}
std::unique_ptr<type_infer_handler_t> GetTypeInferHandler() override {
......@@ -81,8 +82,9 @@ class IoCopyCudaToHostCompute
public:
void Run() override {
auto& param = Param<operators::IoCopyParam>();
CHECK(param.x->target() == TARGET(kCUDA));
auto* data = param.y->mutable_data(TARGET(kHost), param.x->memory_size());
CHECK(TensorGetTarget(*param.x) == TARGET(kCUDA));
auto* data = TensorMutableData<int8_t>(param.y, TARGET(kHost),
param.x->memory_size());
LOG(INFO) << "copy size " << param.x->memory_size();
CopyToHostSync(data, param.x->data<void>(), param.x->memory_size());
}
......
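Both io_copy kernels now move raw bytes through the typed TensorMutableData<int8_t> entry point instead of the old untyped mutable_data(target, size). Distilled to a sketch, the host-to-device path is the following (assuming CopyFromHostSync wraps a synchronous host-to-device memcpy, as its use above suggests):
// Sketch of IoCopyHostToCudaCompute::Run: allocate x.memory_size() bytes on
// the device, then copy them over synchronously.
void CopyHostTensorToCuda(const paddle::lite::Tensor &x,
                          paddle::lite::Tensor *y) {
auto *dst = paddle::lite::TensorMutableData<int8_t>(y, TARGET(kCUDA),
                                                    x.memory_size());
CopyFromHostSync(dst, x.data<int8_t>(), x.memory_size());
}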
......@@ -51,7 +51,8 @@ class MulCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
*/
const auto& param = Param<operators::MulParam>();
param.output->mutable_data<float>(TARGET(kCUDA));
TensorMutableData<float>(param.output, TARGET(kCUDA),
product(param.output->dims()));
LOG(INFO) << "mul output memory size " << param.output->memory_size();
// mul_compute<float>(blas, x, x_h, x_w, y, y_h, y_w, out);
......
......@@ -29,17 +29,16 @@ void FcCompute::Run() {
CHECK_GE(param.input->dims().size(), 2UL);
CHECK_EQ(param.output->dims().size(), 2UL);
fc_compute_eigen(
param.input->data<float>(), // x
product(param.input->dims().begin() + param.in_num_col_dims,
param.input->dims().end()), // x_w
product(param.input->dims().begin(),
param.input->dims().begin() + param.in_num_col_dims), // x_h
param.w->data<float>(), // w
param.w->dims()[1], // w_w
param.w->dims()[0], // w_h
param.bias->data<float>(), // b
param.output->mutable_data<float>());
fc_compute_eigen(param.input->data<float>(), // x
product(param.input->dims(), 0, param.in_num_col_dims),
product(param.input->dims(), param.in_num_col_dims,
param.input->dims().size()),
param.w->data<float>(), // w
param.w->dims()[1], // w_w
param.w->dims()[0], // w_h
param.bias->data<float>(), // b
TensorMutableData<float>(param.output, TARGET(kHost),
product(param.output->dims())));
}
// TargetType FcCompute::target() const { return TARGET(kHost); }
......
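For reference, the math the Eigen kernel implements: the input is flattened to an (x_h, x_w) matrix using in_num_col_dims, the weight is (w_h, w_w) with w_h == x_w, and every output element is a dot product plus bias. A naive sketch of those semantics (not the actual Eigen path; the argument order of fc_compute_eigen is taken from the comments in the old call site):
// Reference semantics for the fully connected kernel: out is (x_h, w_w).
void fc_naive(const float *x, int x_h, int x_w,  // flattened input
              const float *w, int w_h, int w_w,  // weight, w_h == x_w
              const float *b, float *out) {      // bias and output
for (int i = 0; i < x_h; i++) {
for (int j = 0; j < w_w; j++) {
float acc = b ? b[j] : 0.f;
for (int k = 0; k < x_w; k++) acc += x[i * x_w + k] * w[k * w_w + j];
out[i * w_w + j] = acc;
}
}
}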
......@@ -28,7 +28,7 @@ class FeedCompute
void Run() override {
auto &param = Param<operators::FeedParam>();
const Tensor &feed_item = param.feed_list->at(param.col);
param.out->CopyDataFrom(feed_item);
param.out->ShareDataWith(feed_item);
LOG(INFO) << "FEED input " << feed_item << " col " << param.col;
LOG(INFO) << "FEED output " << *param.out;
}
......
......@@ -41,19 +41,18 @@ class MulCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> {
void Run() override {
auto& param = Param<operators::MulParam>();
core::dim2 x_shape({product(param.x->dims().begin(),
param.x->dims().begin() + param.x_num_col_dims),
product(param.x->dims().begin() + param.x_num_col_dims,
param.x->dims().end())});
core::dim2 x_shape({product(param.x->dims(), 0, param.x_num_col_dims),
product(param.x->dims(), param.x_num_col_dims,
param.x->dims().size())});
core::dim2 y_shape({product(param.y->dims().begin(),
param.y->dims().begin() + param.x_num_col_dims),
product(param.y->dims().begin() + param.x_num_col_dims,
param.y->dims().end())});
core::dim2 y_shape({product(param.y->dims(), 0, param.y_num_col_dims),
product(param.y->dims(), param.y_num_col_dims,
param.y->dims().size())});
mul_compute_eigen(param.x->data<float>(), x_shape.x, x_shape.y, //
param.y->data<float>(), y_shape.x, y_shape.y, //
param.output->mutable_data<float>());
TensorMutableData<float>(param.output, TARGET(kHost),
product(param.output->dims())));
LOG(INFO) << "MUL x " << *param.x;
LOG(INFO) << "MUL W " << *param.y;
LOG(INFO) << "MUL out " << *param.output;
......
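The rewrite above replaces iterator-based product calls with the new index-based product(dims, start, end), which works for both DDim representations. A self-contained sketch of what that flattening does (ProductRange mirrors the helper from compatible_tensor.h so the example stands alone):
#include <cassert>
#include <cstdint>
#include <vector>
// Mirrors product(dims, start, end) from compatible_tensor.h.
int ProductRange(const std::vector<int64_t> &dims, int start, int end) {
int res = 1;
for (int i = start; i < end; i++) res *= dims[i];
return res;
}
int main() {
std::vector<int64_t> dims = {2, 3, 4, 5};
int x_num_col_dims = 2;
// The mul kernel sees x as a 6 x 20 matrix:
assert(ProductRange(dims, 0, x_num_col_dims) == 6);  // 2 * 3
assert(ProductRange(dims, x_num_col_dims,
                    static_cast<int>(dims.size())) == 20);  // 4 * 5
return 0;
}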
......@@ -24,10 +24,11 @@ namespace host {
class ReluCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> {
public:
void Run() override {
auto& theparam = Param<operators::ReluParam>();
auto n = product(theparam.input->dims());
const float* input = theparam.input->data<float>();
float* output = theparam.output->mutable_data<float>();
auto& param = Param<operators::ReluParam>();
auto n = product(param.input->dims());
const float* input = param.input->data<float>();
float* output = TensorMutableData<float>(param.output, TARGET(kHost),
product(param.output->dims()));
for (int i = 0; i < n; i++) {
output[i] = std::max(0.f, input[i]);
}
......
......@@ -37,7 +37,9 @@ class ScaleCompute : public KernelLite<TARGET(kHost), PRECISION(kFloat)> {
void Run() override {
auto& param = Param<operators::ScaleParam>();
scale_compute(param.x->data<float>(), param.output->mutable_data<float>(),
scale_compute(param.x->data<float>(),
TensorMutableData<float>(param.output, TARGET(kHost),
product(param.output->dims())),
product(param.x->dims()), param.scale, param.bias,
param.bias_after_scale);
}
......
......@@ -3,7 +3,7 @@ cc_test(test_model_parser_lite SRCS model_parser_test.cc DEPS model_parser_lite)
if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS op_desc_lite var_desc_lite)
else()
cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS framework_proto)
cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS framework_proto proto_desc)
endif(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
set(model_parser_deps variable_lite scope_lite tensor_lite scope_lite
......
......@@ -33,6 +33,7 @@ namespace paddle {
namespace lite {
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
using Attribute = lite::pb::Attribute;
using OpDesc = lite::pb::OpDesc;
using VarDesc = lite::pb::VarDesc;
#else // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
......@@ -41,5 +42,17 @@ using OpDesc = framework::OpDesc;
using VarDesc = framework::VarDesc;
#endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
template <typename T>
T GetAttr(const Attribute& x) {
return x.get<T>();
}
#else
template <typename T>
T GetAttr(const Attribute& x) {
return boost::get<T>(x);
}
#endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
} // namespace lite
} // namespace paddle
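GetAttr<T> hides the difference between the two attribute representations: the lightweight pb::Attribute is read with .get<T>(), while the framework Attribute is a boost::variant read with boost::get<T>. The op definitions further down (fc, feed, fetch, mul, scale) all switch to it; a typical call site looks like this (desc and the "col" attribute are assumed for illustration):
#include "paddle/fluid/lite/model_parser/compatible_pb.h"
// Reads an int attribute through the compatibility shim; the same line
// compiles in both builds.
int GetColAttr(const paddle::lite::OpDesc &desc) {
return paddle::lite::GetAttr<int>(desc.GetAttr("col"));
}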
......@@ -14,8 +14,8 @@
#include "paddle/fluid/lite/model_parser/model_parser.h"
#include <fstream>
#include "paddle/fluid/lite/core/compatible_tensor.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/core/tensor.h"
#include "paddle/fluid/lite/core/variable.h"
namespace paddle {
......@@ -59,16 +59,16 @@ void TensorFromStream(std::istream &is, lite::Tensor *tensor) {
// read tensor
std::vector<int64_t> dims;
dims.reserve(static_cast<size_t>(desc.dims().size()));
std::copy(desc.dims().begin(), desc.dims().end(), std::back_inserter(dims));
tensor->Resize(dims);
tensor->Resize(lite::DDim(&dims[0], dims.size()));
void *buf;
size_t size = product(tensor->dims()) * SizeOfType(desc.data_type());
// allocate memory
switch (static_cast<int>(desc.data_type())) {
#define DO(desc, type) \
case Type::VarType_Type_##desc: \
buf = tensor->mutable_data<type>(); \
#define DO(desc, type) \
case Type::VarType_Type_##desc: \
buf = TensorMutableData<type>(tensor, TensorGetTarget(*tensor), \
product(tensor->dims())); \
break;
DO(BOOL, bool);
DO(FP32, float);
......@@ -198,7 +198,8 @@ void TensorToStream(std::ostream &os, const lite::Tensor &tensor) {
auto dims = tensor.dims();
auto *pb_dims = desc.mutable_dims();
pb_dims->Resize(static_cast<int>(dims.size()), 0);
std::copy(dims.begin(), dims.end(), pb_dims->begin());
auto dims_vec = DDimVectorize(dims);
std::copy(dims_vec.begin(), dims_vec.end(), pb_dims->begin());
int32_t size = desc.ByteSize();
os.write(reinterpret_cast<const char *>(&size), sizeof(size));
auto out = desc.SerializeAsString();
......@@ -210,9 +211,9 @@ void TensorToStream(std::ostream &os, const lite::Tensor &tensor) {
<< "Index overflow when writing tensor";
#ifdef LITE_WITH_CUDA
if (tensor.target() == TARGET(kCUDA)) {
if (TensorGetTarget(tensor) == TARGET(kCUDA)) {
std::unique_ptr<char[]> tmp_buffer(new char[size]);
TargetWrapperCuda::MemcpySync(tmp_buffer.get(), tensor.data<char>(),
TargetWrapperCuda::MemcpySync(tmp_buffer.get(), tensor.data<float>(),
tensor.memory_size(), IoDirection::DtoH);
os.write(static_cast<const char *>(tmp_buffer.get()),
static_cast<std::streamsize>(size));
......
......@@ -20,7 +20,6 @@
#include <vector>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/core/tensor.h"
#include "paddle/fluid/lite/core/variable.h"
namespace paddle {
......
......@@ -58,7 +58,7 @@ bool FcOpLite::InferShape() const {
output_dims[i] = input_dims[i];
}
output_dims.back() = w_dims[1];
param_.output->Resize(output_dims);
param_.output->Resize(DDim(&output_dims[0], output_dims.size()));
// share LoD
// param_.output->set_lod(param_.input->lod());
......
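A worked sketch of the InferShape logic above (values are illustrative): with input dims {4, 8} and weight dims {8, 16}, the output keeps the input's leading dims and replaces the last one with w_dims[1]; the resulting vector is then wrapped into the compatible DDim via DDim(&output_dims[0], output_dims.size()).
#include <cstdint>
#include <vector>
// Hypothetical standalone version of the shape rule used by FcOpLite.
std::vector<int64_t> InferFcOutputDims(const std::vector<int64_t> &input_dims,
                                       const std::vector<int64_t> &w_dims) {
std::vector<int64_t> output_dims(input_dims);
output_dims.back() = w_dims[1];  // {4, 8} x {8, 16} -> {4, 16}
return output_dims;
}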
......@@ -16,10 +16,10 @@
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/compatible_tensor.h"
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/core/tensor.h"
#include "paddle/fluid/lite/operators/op_params.h"
#include "paddle/fluid/lite/utils/all.h"
......@@ -57,7 +57,7 @@ class FcOpLite : public OpLite {
param_.bias = scope->FindVar(bias)->GetMutable<Tensor>();
CHECK(scope->FindVar(out));
param_.output = scope->FindVar(out)->GetMutable<Tensor>();
param_.in_num_col_dims = op_desc.GetAttr("in_num_col_dims").get<int>();
param_.in_num_col_dims = GetAttr<int>(op_desc.GetAttr("in_num_col_dims"));
CHECK(kernel_);
kernel_->SetParam(param_);
......
......@@ -49,7 +49,7 @@ class FeedOp : public OpLite {
// NOTE need boost here
// TODO(Superjomn) drop the need for framework::op_desc
param_.col = opdesc.GetAttr("col").get<int>();
param_.col = GetAttr<int>(opdesc.GetAttr("col"));
return true;
}
......
......@@ -43,7 +43,7 @@ class FetchOp : public OpLite {
auto* out = scope->FindVar(_out);
param_.fetch_list = out->GetMutable<std::vector<lite::Tensor>>();
param_.col = opdesc.GetAttr("col").get<int>();
param_.col = GetAttr<int>(opdesc.GetAttr("col"));
return true;
}
......
......@@ -45,7 +45,7 @@ bool MulOpLite::InferShape() const {
}
out_dims.back() = y_dims[1];
param_.output->Resize(out_dims);
param_.output->Resize(DDim(&out_dims[0], out_dims.size()));
// share LoD
// param_.output->set_lod(param_.input->lod());
......
......@@ -18,7 +18,6 @@
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/core/tensor.h"
#include "paddle/fluid/lite/operators/op_params.h"
#include "paddle/fluid/lite/utils/all.h"
......@@ -47,8 +46,8 @@ class MulOpLite : public OpLite {
param_.y = scope->FindVar(W)->GetMutable<Tensor>();
CHECK(scope->FindVar(out));
param_.output = scope->FindVar(out)->GetMutable<Tensor>();
param_.x_num_col_dims = op_desc.GetAttr("x_num_col_dims").get<int>();
param_.y_num_col_dims = op_desc.GetAttr("y_num_col_dims").get<int>();
param_.x_num_col_dims = GetAttr<int>(op_desc.GetAttr("x_num_col_dims"));
param_.y_num_col_dims = GetAttr<int>(op_desc.GetAttr("y_num_col_dims"));
return true;
}
......
......@@ -13,7 +13,7 @@
// limitations under the License.
#pragma once
#include "paddle/fluid/lite/core/tensor.h"
#include "paddle/fluid/lite/core/compatible_tensor.h"
#include "paddle/fluid/lite/utils/all.h"
/*
......
......@@ -16,7 +16,6 @@
#include <vector>
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/core/tensor.h"
#include "paddle/fluid/lite/utils/all.h"
namespace paddle {
......
......@@ -18,7 +18,6 @@
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/core/tensor.h"
#include "paddle/fluid/lite/operators/op_params.h"
#include "paddle/fluid/lite/utils/all.h"
......@@ -53,9 +52,10 @@ class ScaleOp : public OpLite {
param_.x = scope->FindVar(x)->GetMutable<Tensor>();
CHECK(scope->FindVar(out));
param_.output = scope->FindVar(out)->GetMutable<Tensor>();
param_.scale = op_desc.GetAttr("scale").get<float>();
param_.bias = op_desc.GetAttr("bias").get<float>();
param_.bias_after_scale = op_desc.GetAttr("bias_after_scale").get<bool>();
param_.scale = GetAttr<float>(op_desc.GetAttr("scale"));
param_.bias = GetAttr<float>(op_desc.GetAttr("bias"));
param_.bias_after_scale =
GetAttr<bool>(op_desc.GetAttr("bias_after_scale"));
return true;
}
......
if (NOT LITE_WITH_X86)
return()
endif()
cc_library(target_wrapper_x86 SRCS target_wrapper.cc)