Commit 7446e257 authored by DannyIsFunny

Merge remote-tracking branch 'origin' into test_result

@@ -54,7 +54,8 @@ enum class TargetType : int {
kXPU = 9,
kBM = 10,
kAny = 6,  // any target
-NUM = 11,  // number of fields.
+kMLU = 11,
+NUM = 12,  // number of fields.
};
enum class PrecisionType : int {
kUnk = 0,
@@ -98,7 +99,8 @@ enum class ActivationType : int {
kTanh = 6,
kSwish = 7,
kExp = 8,
-NUM = 9,
+kAbs = 9,
+NUM = 10,
};
static size_t PrecisionTypeLength(PrecisionType type) {
......
@@ -29,6 +29,7 @@ enum class BinaryOperation {
kADD = 0,
kMUL = 1,
kDIV = 2,
+kSUB = 3,
};
template <typename T>
@@ -41,6 +42,7 @@ __device__ __forceinline__ float binary_calc(float x,
if (type == BinaryOperation::kADD) return x + y;
if (type == BinaryOperation::kMUL) return x * y;
if (type == BinaryOperation::kDIV) return x / y;
+if (type == BinaryOperation::kSUB) return x - y;
}
template <typename T>
......
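Note: the new kSUB branch above is what the elementwise_sub CUDA kernels added later in this commit presumably dispatch to. A minimal, illustrative call site is sketched below; it is not part of the patch, and it assumes binary_calc(x, y, type) is the helper's full signature (the exact template arguments may differ in the real header).

```cpp
// Illustrative sketch only -- not part of this commit.
// Assumes a device helper of the form binary_calc(float x, float y, BinaryOperation type).
__global__ void elementwise_sub_sketch(const float* x,
                                       const float* y,
                                       float* out,
                                       int num) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < num) {
    // Exercises the new kSUB branch of binary_calc.
    out[idx] = binary_calc(x[idx], y[idx], BinaryOperation::kSUB);
  }
}
```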
@@ -52,6 +52,7 @@ using XPUContext = Context<TargetType::kXPU>;
using OpenCLContext = Context<TargetType::kOpenCL>;
using FPGAContext = Context<TargetType::kFPGA>;
using BMContext = Context<TargetType::kBM>;
+using MLUContext = Context<TargetType::kMLU>;
template <>
class Context<TargetType::kHost> {
......
@@ -22,6 +22,61 @@
namespace paddle {
namespace lite {
bool OpLite::InferShape() {
// if input_tensor_ptrs and output_tensor_ptrs are overloaded in param_,
// InferShapeWithCache will be applied.
if (param_.input_tensor_ptrs() && param_.output_tensor_ptrs()) {
return this->InferShapeWithCache();
} else {
// otherwise, InferShapeImpl is applied directly.
return this->InferShapeImpl();
}
}
bool OpLite::InferShapeWithCache() {
// 1. Get vector of current input tensors
auto *current_inputs = param_.input_tensor_ptrs();
// 2. Compute a hash of the shapes and lods of the current inputs
size_t new_hash = 0;
for (auto iter = current_inputs->begin(); iter != current_inputs->end();
iter++) {
// Combine dims values into new_hash.
auto &element_dims = (*iter)->dims();
for (int i = 0; i < element_dims.size(); i++) {
new_hash =
lite::hash_combine(new_hash, static_cast<int>(element_dims[i]));
}
// Combine lod values into new_hash.
auto &element_lods = (*iter)->lod();
for (auto lod_iter = element_lods.begin(); lod_iter != element_lods.end();
lod_iter++) {
for (int i = 0; i < lod_iter->size(); i++) {
new_hash =
lite::hash_combine(new_hash, static_cast<int>(lod_iter->at(i)));
}
}
}
// 3. infer shapes of output tensors
if (new_hash == io_shape_lod_hash_ && new_hash != 0) {
// if the current hash value is consistent with io_shape_lod_hash_,
// the previous output shapes and lods are reused.
auto *current_outputs = param_.output_tensor_ptrs();
for (int i = 0; i < current_outputs->size(); i++) {
current_outputs->at(i)->Resize(last_output_shapes[i]);
current_outputs->at(i)->set_lod(last_output_lods[i]);
}
} else {
// otherwise, the hash value has changed and InferShapeImpl is applied.
io_shape_lod_hash_ = new_hash;
this->InferShapeImpl();
auto *current_outputs = param_.output_tensor_ptrs();
for (int i = 0; i < current_outputs->size(); i++) {
last_output_shapes[i] = current_outputs->at(i)->dims();
last_output_lods[i] = current_outputs->at(i)->lod();
}
}
return true;
}
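For reference, lite::hash_combine is not shown in this diff; the new #include <functional> added to op_lite.h below suggests a std::hash-based helper. The following is only a boost-style sketch of what such a function conventionally looks like, matching the call pattern new_hash = lite::hash_combine(new_hash, value) used above; it is an assumption, not the actual Paddle-Lite implementation.

```cpp
// Sketch only -- the real lite::hash_combine may differ in signature and constants.
#include <cstddef>
#include <functional>

namespace lite {
template <typename T>
inline std::size_t hash_combine(std::size_t seed, const T& value) {
  // Conventional boost-style mixing with the golden-ratio constant.
  seed ^= std::hash<T>()(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
  return seed;
}
}  // namespace lite
```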
std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
    const std::vector<Place> &places, const std::string &kernel_type) {
std::vector<std::unique_ptr<KernelBase>> kernels;
......
@@ -14,6 +14,7 @@
#pragma once
#include <functional>
#include <list>
#include <map>
#include <memory>
@@ -24,6 +25,7 @@
#include "lite/core/kernel.h"
#include "lite/core/scope.h"
#include "lite/model_parser/cpp/op_desc.h"
#include "lite/operators/op_params.h"
namespace paddle {
namespace lite {
@@ -64,8 +66,8 @@ class OpLite : public Registry {
// Check the shape.
virtual bool CheckShape() const { return true; }
// Infer the outputs' shape.
-virtual bool InferShape() const { return true; }
-virtual bool SmartInferShape() { return this->InferShape(); }
+virtual bool InferShapeImpl() const { return true; }
+virtual bool InferShape();
// Run this operator.
virtual bool Run();
// Indicate whether the Op runs only once or not
@@ -151,10 +153,16 @@ class OpLite : public Registry {
std::vector<Place> valid_places_;
Place kernel_place_{TARGET(kHost), PRECISION(kFloat)};
std::unique_ptr<OpInfo> op_info_;
-std::vector<DDimLite> last_input_shapes;
-std::vector<DDimLite> last_output_shapes;
-std::vector<std::vector<std::vector<uint64_t>>> last_output_lods;
-std::vector<std::vector<std::vector<uint64_t>>> last_input_lods;
+std::vector<DDimLite> last_output_shapes{};
+std::vector<std::vector<std::vector<uint64_t>>> last_output_lods{};
+size_t io_shape_lod_hash_{};
+mutable operators::ParamBase param_;
+private:
+// Infer shapes with a cache: if the current input shapes and lods are
+// consistent with those of the previous run, the previous output shapes
+// and lods are reused.
+bool InferShapeWithCache();
};
/*
......
@@ -107,6 +107,9 @@ std::list<std::unique_ptr<KernelBase>> KernelRegistry::Create(
case TARGET(kBM): {
CREATE_KERNEL(kBM);
} break;
case TARGET(kMLU): {
CREATE_KERNEL(kMLU);
} break;
default:
CHECK(false) << "not supported kernel target " << TargetToStr(target);
}
@@ -139,6 +142,15 @@ KernelRegistry::KernelRegistry()
INIT_FOR(kCUDA, kInt64, kNCHW);
INIT_FOR(kCUDA, kInt64, kNHWC);
INIT_FOR(kMLU, kFloat, kNHWC);
INIT_FOR(kMLU, kFloat, kNCHW);
INIT_FOR(kMLU, kFP16, kNHWC);
INIT_FOR(kMLU, kFP16, kNCHW);
INIT_FOR(kMLU, kInt8, kNHWC);
INIT_FOR(kMLU, kInt8, kNCHW);
INIT_FOR(kMLU, kInt16, kNHWC);
INIT_FOR(kMLU, kInt16, kNCHW);
INIT_FOR(kHost, kFloat, kNCHW);
INIT_FOR(kHost, kAny, kNCHW);
INIT_FOR(kHost, kFloat, kNHWC);
......
@@ -268,7 +268,32 @@ class KernelRegistry final {
DATALAYOUT(kAny)> *,  //
KernelRegistryForTarget<TARGET(kFPGA),
PRECISION(kAny),
-DATALAYOUT(kAny)> *  //
+DATALAYOUT(kAny)> *,  //
KernelRegistryForTarget<TARGET(kMLU),
PRECISION(kFloat),
DATALAYOUT(kNHWC)> *, //
KernelRegistryForTarget<TARGET(kMLU),
PRECISION(kFloat),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kMLU),
PRECISION(kFP16),
DATALAYOUT(kNHWC)> *, //
KernelRegistryForTarget<TARGET(kMLU),
PRECISION(kFP16),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kMLU),
PRECISION(kInt8),
DATALAYOUT(kNHWC)> *, //
KernelRegistryForTarget<TARGET(kMLU),
PRECISION(kInt8),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kMLU),
PRECISION(kInt16),
DATALAYOUT(kNHWC)> *, //
KernelRegistryForTarget<TARGET(kMLU),
PRECISION(kInt16),
DATALAYOUT(kNCHW)> * //
>;
KernelRegistry();
......
@@ -286,8 +286,7 @@ void Instruction::Run() {
return;
}
-// op_->InferShape();
-op_->SmartInferShape();
+op_->InferShape();
kernel_->Launch();
has_run_ = true;
}
......
@@ -59,7 +59,8 @@ void SequencePoolCompute::Run() {
for (int i = 0; i <= batch_size; i++) {
offset_new[i] = i;
}
-(output->mutable_lod())->push_back(offset_new);
+output->mutable_lod()->clear();
+output->mutable_lod()->push_back(offset_new);
}
}  // namespace arm
......
@@ -8,6 +8,8 @@ add_kernel(mul_compute_cuda CUDA basic SRCS mul_compute.cc DEPS ${lite_kernel_deps})
add_kernel(search_group_padding_compute_cuda CUDA basic SRCS search_group_padding_compute.cu DEPS ${lite_kernel_deps})
add_kernel(io_copy_compute_cuda CUDA basic SRCS io_copy_compute.cc DEPS ${lite_kernel_deps})
add_kernel(leaky_relu_compute_cuda CUDA basic SRCS leaky_relu_compute.cu DEPS ${lite_kernel_deps})
add_kernel(abs_compute_cuda CUDA basic SRCS abs_compute.cu DEPS ${lite_kernel_deps})
add_kernel(tanh_compute_cuda CUDA basic SRCS tanh_compute.cu DEPS ${lite_kernel_deps})
add_kernel(relu_compute_cuda CUDA basic SRCS relu_compute.cu DEPS ${lite_kernel_deps})
add_kernel(yolo_box_compute_cuda CUDA basic SRCS yolo_box_compute.cu DEPS ${lite_kernel_deps})
add_kernel(sequence_pool_compute_cuda CUDA extra SRCS sequence_pool_compute.cu DEPS ${lite_kernel_deps})
@@ -45,6 +47,8 @@ lite_cc_test(calib_compute_cuda_test SRCS calib_compute_cuda_test.cc DEPS calib_
#nv_test(conv2d_cuda_test SRCS conv_compute_test.cc DEPS conv2d_cuda)
nv_test(nearest_interp_compute_cuda_test SRCS nearest_interp_compute_test.cc DEPS nearest_interp_compute_cuda)
nv_test(leaky_relu_compute_cuda_test SRCS leaky_relu_compute_test.cc DEPS leaky_relu_compute_cuda)
nv_test(abs_compute_cuda_test SRCS abs_compute_test.cc DEPS abs_compute_cuda)
nv_test(tanh_compute_cuda_test SRCS tanh_compute_test.cc DEPS tanh_compute_cuda)
nv_test(relu_compute_cuda_test SRCS relu_compute_test.cc DEPS relu_compute_cuda)
nv_test(yolo_box_compute_cuda_test SRCS yolo_box_compute_test.cc DEPS yolo_box_compute_cuda)
nv_test(transpose_compute_cuda_test SRCS transpose_compute_test.cc DEPS transpose_compute_cuda)
@@ -61,7 +65,7 @@ nv_test(sequence_reverse_compute_cuda_test SRCS sequence_reverse_compute_test.cc
#nv_test(sequence_concat_compute_cuda_test SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_cuda)
#nv_test(attention_padding_mask_compute_cuda_test SRCS attention_padding_mask_compute_test.cc DEPS attention_padding_mask_compute_cuda)
nv_test(sequence_arithmetic_compute_cuda_test SRCS sequence_arithmetic_compute_test.cc DEPS sequence_arithmetic_compute_cuda)
-#nv_test(search_fc_cuda_test SRCS search_fc_compute_test.cc DEPS search_fc_compute_cuda sequence_topk_avg_pooling_compute_cuda)
+#nv_test(search_fc_cuda_test SRCS search_fc_compute_test.cc DEPS search_fc_compute_cuda)
#nv_test(var_conv_2d_compute_cuda_test SRCS var_conv_2d_compute_test.cc DEPS var_conv_2d_compute_cuda)
if(LITE_BUILD_EXTRA)
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/abs_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename T>
__global__ void AbsKernel(const int num, const T* input, T* output);
template <>
__global__ void AbsKernel<float>(const int num,
const float* input,
float* output) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < num) {
output[index] = fabsf(input[index]);
}
}
template <>
__global__ void AbsKernel<double>(const int num,
const double* input,
double* output) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < num) {
output[index] = fabs(input[index]);
}
}
void AbsCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
int num = static_cast<int>(param.X->numel());
auto input = param.X->data<float>();
auto output = param.Out->mutable_data<float>(TARGET(kCUDA));
const int threads = 512;
const int blocks = (num + threads - 1) / threads;
AbsKernel<float><<<blocks, threads, 0, stream>>>(num, input, output);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(ERROR) << cudaGetErrorString(error);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
abs, kCUDA, kFloat, kNCHW, paddle::lite::kernels::cuda::AbsCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class AbsCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::ActivationParam;
void Run() override;
virtual ~AbsCompute() = default;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/abs_compute.h"
#include <gtest/gtest.h>
#include <cmath>
#include <memory>
#include <utility>
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
TEST(abs, normal) {
AbsCompute abs_kernel;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
operators::ActivationParam param;
Tensor x, y, x_cpu, y_cpu;
int h = 3, w = 3;
y.Resize({h, w});
x_cpu.Resize({h, w});
y_cpu.Resize({h, w});
auto* y_data = y.mutable_data<float>(TARGET(kCUDA));
float* x_cpu_data = x_cpu.mutable_data<float>();
float* y_cpu_data = y_cpu.mutable_data<float>();
for (int i = 0; i < x_cpu.numel(); i++) {
x_cpu_data[i] = i - 1.5;
}
x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
param.X = &x;
param.Out = &y;
abs_kernel.SetParam(param);
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
abs_kernel.SetContext(std::move(ctx));
abs_kernel.Launch();
cudaDeviceSynchronize();
CopySync<TARGET(kCUDA)>(
y_cpu_data, y_data, sizeof(float) * y.numel(), IoDirection::DtoH);
for (int i = 0; i < y.numel(); i++) {
EXPECT_NEAR(y_cpu_data[i], std::fabs(x_cpu_data[i]), 1e-5);
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
@@ -152,6 +152,18 @@ void ElementwiseAddComputeNHWC::Run() {
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
void ElementwiseSubCompute::Run() {
ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kSUB, false)
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
void ElementwiseSubComputeNHWC::Run() {
ELEMENTWISE_COMPUTE_NHWC(lite::cuda::math::BinaryOperation::kSUB, false)
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
void ElementwiseMulCompute::Run() {
ELEMENTWISE_COMPUTE(lite::cuda::math::BinaryOperation::kMUL, false)
cudaError_t error = cudaGetLastError();
@@ -204,6 +216,17 @@ REGISTER_LITE_KERNEL(elementwise_add,
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
REGISTER_LITE_KERNEL(elementwise_sub,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::ElementwiseSubCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
REGISTER_LITE_KERNEL(elementwise_add,
kCUDA,
kFloat,
@@ -224,6 +247,26 @@ REGISTER_LITE_KERNEL(elementwise_add,
DATALAYOUT(kNHWC))})
.Finalize();
REGISTER_LITE_KERNEL(elementwise_sub,
kCUDA,
kFloat,
kNHWC,
paddle::lite::kernels::cuda::ElementwiseSubComputeNHWC,
nhwc_format)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindInput("Y",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNHWC))})
.Finalize();
REGISTER_LITE_KERNEL(elementwise_mul,
kCUDA,
kFloat,
......
@@ -38,6 +38,24 @@ class ElementwiseAddComputeNHWC
virtual ~ElementwiseAddComputeNHWC() = default;
};
class ElementwiseSubCompute
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::ElementwiseParam;
void Run() override;
virtual ~ElementwiseSubCompute() = default;
};
class ElementwiseSubComputeNHWC
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNHWC)> {
public:
using param_t = operators::ElementwiseParam;
void Run() override;
virtual ~ElementwiseSubComputeNHWC() = default;
};
class ElementwiseMulCompute
    : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
 public:
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/tanh_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename T>
__global__ void TanhKernel(const int num, const T* input, T* output) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < num) {
output[index] = tanh(input[index]);
}
}
void TanhCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
int num = static_cast<int>(param.X->numel());
auto input = param.X->data<float>();
auto output = param.Out->mutable_data<float>(TARGET(kCUDA));
const int threads = 512;
const int blocks = (num + threads - 1) / threads;
TanhKernel<float><<<blocks, threads, 0, stream>>>(num, input, output);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(ERROR) << cudaGetErrorString(error);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
tanh, kCUDA, kFloat, kNCHW, paddle::lite::kernels::cuda::TanhCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cmath>
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class TanhCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::ActivationParam;
void Run() override;
virtual ~TanhCompute() = default;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/tanh_compute.h"
#include <gtest/gtest.h>
#include <cmath>
#include <memory>
#include <utility>
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
TEST(tanh, fp32) {
TanhCompute tanh_kernel;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
operators::ActivationParam param;
Tensor x, y, x_cpu, y_cpu;
int h = 3, w = 3;
y.Resize({h, w});
x_cpu.Resize({h, w});
y_cpu.Resize({h, w});
auto* y_data = y.mutable_data<float>(TARGET(kCUDA));
float* x_cpu_data = x_cpu.mutable_data<float>();
float* y_cpu_data = y_cpu.mutable_data<float>();
for (int i = 0; i < x_cpu.numel(); i++) {
x_cpu_data[i] = i - 1.5;
}
x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
param.X = &x;
param.Out = &y;
tanh_kernel.SetParam(param);
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
tanh_kernel.SetContext(std::move(ctx));
tanh_kernel.Launch();
cudaDeviceSynchronize();
CopySync<TARGET(kCUDA)>(
y_cpu_data, y_data, sizeof(float) * y.numel(), IoDirection::DtoH);
for (int i = 0; i < y.numel(); i++) {
EXPECT_NEAR(y_cpu_data[i], tanh(x_cpu_data[i]), 1e-5);
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
@@ -220,6 +220,8 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
act_op->set_attr_mode(CvtActMode(act_type));
if (act_type == "leaky_relu") {
act_op->set_attr_negative_slope(leaky_relu_alpha);
} else if (act_type == "relu6") {
act_op->set_attr_coef(6.f);
}
}
......
@@ -18,6 +18,7 @@ USE_SUBGRAPH_BRIDGE(sigmoid, kNPU);
USE_SUBGRAPH_BRIDGE(relu, kNPU);
USE_SUBGRAPH_BRIDGE(tanh, kNPU);
USE_SUBGRAPH_BRIDGE(relu_clipped, kNPU);
USE_SUBGRAPH_BRIDGE(relu6, kNPU);
USE_SUBGRAPH_BRIDGE(leaky_relu, kNPU);
USE_SUBGRAPH_BRIDGE(softsign, kNPU);
USE_SUBGRAPH_BRIDGE(hard_sigmoid, kNPU);
......
@@ -99,10 +99,8 @@ int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) {
ksize);
// ceil mode
-int ceil_mode = 0;
-if (op_info->HasAttr("ceil_mode")) {
-  ceil_mode = op_info->GetAttr<bool>("ceil_mode") ? 1 : 0;
-}
+bool ceil_mode =
+    op_info->HasAttr("ceil_mode") && op_info->GetAttr<bool>("ceil_mode");
// Pooling node
auto pool_node = graph->Add<ge::op::Pooling>(out_name);
@@ -112,12 +110,14 @@ int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) {
pool_op->set_attr_pad_mode(pad_mode);
pool_op->set_attr_global_pooling(global_pooling);
pool_op->set_attr_window(ge::AttrValue::LIST_INT(ksize.begin(), ksize.end()));
-pool_op->set_attr_pad(ge::AttrValue::LIST_INT{
-    paddings[0], paddings[1], paddings[2], paddings[3]});
+pool_op->set_attr_pad(
+    ge::AttrValue::LIST_INT(paddings.begin(), paddings.end()));
pool_op->set_attr_stride(
    ge::AttrValue::LIST_INT(strides.begin(), strides.end()));
-pool_op->set_attr_ceil_mode(ceil_mode);
-// pool_op->set_attr_data_mode(data_mode);
+if (ceil_mode) {
+  pool_op->set_attr_ceil_mode(1);
+  pool_op->set_attr_data_mode(0);
+}
return REBUILD_WHEN_SHAPE_CHANGED;
}
......
@@ -35,7 +35,7 @@ int SubgraphEngine::BuildDeviceProgram() {
subgraph::npu::Graph graph;
const auto& bridges = subgraph::Registry::Instance();
for (auto& inst : origin_program_) {
-auto op = inst.op();
+auto op = const_cast<OpLite*>(inst.op());
CHECK(op);
op->CheckShape();
op->InferShape();
@@ -44,10 +44,8 @@ int SubgraphEngine::BuildDeviceProgram() {
return subgraph::FAILED;
}
auto kernel = inst.kernel();
-status |=
-    bridges.Select(op_type, TARGET(kNPU))(reinterpret_cast<void*>(&graph),
-                                          const_cast<OpLite*>(op),
-                                          const_cast<KernelBase*>(kernel));
+status |= bridges.Select(op_type, TARGET(kNPU))(
+    reinterpret_cast<void*>(&graph), op, const_cast<KernelBase*>(kernel));
if (subgraph::CHECK_FAILED(status)) {
return subgraph::FAILED;
}
......
@@ -25,7 +25,7 @@ bool ActivationGradOp::CheckShape() const {
return true;
}
-bool ActivationGradOp::InferShape() const {
+bool ActivationGradOp::InferShapeImpl() const {
param_.X_grad->Resize(param_.Out_grad->dims());
return true;
}
......
@@ -26,7 +26,7 @@ class ActivationGradOp : public OpLite {
bool CheckShape() const override;
-bool InferShape() const override;
+bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
......
@@ -25,7 +25,7 @@ bool ActivationOp::CheckShape() const {
return true;
}
-bool ActivationOp::InferShape() const {
+bool ActivationOp::InferShapeImpl() const {
param_.Out->Resize(param_.X->dims());
auto out_lod = param_.Out->mutable_lod();
*out_lod = param_.X->lod();
@@ -71,6 +71,9 @@ bool ActivationOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
} else if (opdesc.Type() == "exp") {
// exp
param_.active_type = lite_api::ActivationType::kExp;
} else if (opdesc.Type() == "abs") {
// abs
param_.active_type = lite_api::ActivationType::kAbs;
}
VLOG(4) << "opdesc.Type():" << opdesc.Type();
@@ -92,6 +95,7 @@ REGISTER_LITE_OP(swish, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(relu6, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(log, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(exp, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(abs, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(floor, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(hard_sigmoid, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(sqrt, paddle::lite::operators::ActivationOp);
......
@@ -26,7 +26,7 @@ class ActivationOp : public OpLite {
bool CheckShape() const override;
-bool InferShape() const override;
+bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
......
@@ -44,7 +44,7 @@ bool AffineChannelOpLite::CheckShape() const {
return true;
}
-bool AffineChannelOpLite::InferShape() const {
+bool AffineChannelOpLite::InferShapeImpl() const {
const auto x_dims = param_.X->dims();
param_.Out->Resize(x_dims);
return true;
......
@@ -31,7 +31,7 @@ class AffineChannelOpLite : public OpLite {
bool CheckShape() const override;
-bool InferShape() const override;
+bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
@@ -31,7 +31,7 @@ bool AnchorGeneratorOpLite::CheckShape() const {
return true;
}
-bool AnchorGeneratorOpLite::InferShape() const {
+bool AnchorGeneratorOpLite::InferShapeImpl() const {
auto input_dims = param_.Input->dims();
size_t num_anchors = param_.aspect_ratios.size() * param_.anchor_sizes.size();
std::vector<int64_t> output_shape(
......
@@ -32,7 +32,7 @@ class AnchorGeneratorOpLite : public OpLite {
bool CheckShape() const override;
-bool InferShape() const override;
+bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
@@ -29,7 +29,7 @@ bool ArgmaxOpLite::CheckShape() const {
return true;
}
-bool ArgmaxOpLite::InferShape() const {
+bool ArgmaxOpLite::InferShapeImpl() const {
auto x_dims = param_.X->dims();
int x_rank = x_dims.size();
int axis = param_.Axis;
......
@@ -31,7 +31,7 @@ class ArgmaxOpLite : public OpLite {
bool CheckShape() const override;
-bool InferShape() const override;
+bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
@@ -26,7 +26,7 @@ bool AssignOpLite::CheckShape() const {
return true;
}
-bool AssignOpLite::InferShape() const {
+bool AssignOpLite::InferShapeImpl() const {
lite::DDim input_dims;
input_dims = param_.X->dims();
param_.Out->Resize(lite::DDim(input_dims));
......
@@ -30,7 +30,7 @@ class AssignOpLite : public OpLite {
bool CheckShape() const override;
-bool InferShape() const override;
+bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
@@ -35,7 +35,7 @@ bool AssignValueOpLite::CheckShape() const {
return true;
}
-bool AssignValueOpLite::InferShape() const {
+bool AssignValueOpLite::InferShapeImpl() const {
std::vector<int> shape = param_.shape;
std::vector<int64_t> out_shape;
for (size_t i = 0; i < shape.size(); i++) out_shape.push_back(shape[i]);
......
@@ -31,7 +31,7 @@ class AssignValueOpLite : public OpLite {
bool CheckShape() const override;
-bool InferShape() const override;
+bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
@@ -28,7 +28,7 @@ bool AttentionPaddingMaskOp::CheckShape() const {
return true;
}
-bool AttentionPaddingMaskOp::InferShape() const {
+bool AttentionPaddingMaskOp::InferShapeImpl() const {
auto src_len = param_.X->lod()[0][1];
CHECK_EQ(src_len, param_.X->dims()[1])
<< "Mismatch source length, expect: " << src_len
......
@@ -29,7 +29,7 @@ class AttentionPaddingMaskOp : public OpLite {
bool CheckShape() const override;
-bool InferShape() const override;
+bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
@@ -34,7 +34,7 @@ bool AxpyOpLite::CheckShape() const {
return true;
}
-bool AxpyOpLite::InferShape() const {
+bool AxpyOpLite::InferShapeImpl() const {
auto dims = param_.Bias->dims();
// Set output dims
......
@@ -31,7 +31,7 @@ class AxpyOpLite : public OpLite {
bool CheckShape() const override;
-bool InferShape() const override;
+bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
@@ -46,7 +46,7 @@ bool BatchNormOp::CheckShape() const {
return true;
}
-bool BatchNormOp::InferShape() const {
+bool BatchNormOp::InferShapeImpl() const {
auto x_dims = param_.x->dims();
int64_t channel_size = 0;
switch (param_.data_layout) {
......
@@ -30,7 +30,7 @@ class BatchNormOp : public OpLite {
bool CheckShape() const override;
-bool InferShape() const override;
+bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
@@ -28,7 +28,7 @@ bool BeamSearchDecodeOpLite::CheckShape() const {
return true;
}
-bool BeamSearchDecodeOpLite::InferShape() const { return true; }
+bool BeamSearchDecodeOpLite::InferShapeImpl() const { return true; }
bool BeamSearchDecodeOpLite::AttachImpl(const cpp::OpDesc &op_desc,
                                        lite::Scope *scope) {
......
@@ -31,7 +31,7 @@ class BeamSearchDecodeOpLite : public OpLite {
bool CheckShape() const override;
-bool InferShape() const override;
+bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
@@ -30,7 +30,7 @@ bool BeamSearchOp::CheckShape() const {
return true;
}
-bool BeamSearchOp::InferShape() const { return true; }
+bool BeamSearchOp::InferShapeImpl() const { return true; }
bool BeamSearchOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.pre_ids = scope->FindTensor(opdesc.Input("pre_ids").front());
......
@@ -30,7 +30,7 @@ class BeamSearchOp : public OpLite {
bool CheckShape() const override;
-bool InferShape() const override;
+bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
@@ -35,7 +35,7 @@ bool BoxClipOpLite::CheckShape() const {
return true;
}
-bool BoxClipOpLite::InferShape() const {
+bool BoxClipOpLite::InferShapeImpl() const {
auto* input = param_.Input;
auto* output = param_.Output;
output->Resize(input->dims());
......
@@ -31,7 +31,7 @@ class BoxClipOpLite : public OpLite {
bool CheckShape() const override;
-bool InferShape() const override;
+bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
@@ -35,7 +35,7 @@ bool BoxCoderOpLite::CheckShape() const {
return true;
}
-bool BoxCoderOpLite::InferShape() const {
+bool BoxCoderOpLite::InferShapeImpl() const {
auto prior_box_dims = param_.prior_box->dims();
auto target_box_dims = param_.target_box->dims();
std::string code_type = param_.code_type;
......
@@ -29,7 +29,7 @@ class BoxCoderOpLite : public OpLite {
bool CheckShape() const override;
-bool InferShape() const override;
+bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
@@ -24,7 +24,7 @@ bool CalibOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.output);
return true;
}
-bool CalibOpLite::InferShape() const {
+bool CalibOpLite::InferShapeImpl() const {
param_.output->Resize(param_.input->dims());
return true;
}
......
@@ -42,7 +42,7 @@ class CalibOpLite : public OpLite {
bool CheckShape() const override;
-bool InferShape() const override;
+bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope);
......
@@ -25,7 +25,7 @@ bool CastOp::CheckShape() const {
return true;
}
-bool CastOp::InferShape() const {
+bool CastOp::InferShapeImpl() const {
CHECK_OR_FALSE(param_.Out);
// TODO(Superjomn) Enable data sharing.
auto out_dims = param_.X->dims();
......
@@ -30,7 +30,7 @@ class CastOp : public OpLite {
bool CheckShape() const override;
-bool InferShape() const override;
+bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
@@ -43,7 +43,7 @@ bool CollectFpnProposalsOpLite::CheckShape() const {
return true;
}
-bool CollectFpnProposalsOpLite::InferShape() const {
+bool CollectFpnProposalsOpLite::InferShapeImpl() const {
param_.fpn_rois->Resize({param_.post_nms_topN, 4});
return true;
......
@@ -32,7 +32,7 @@ class CollectFpnProposalsOpLite : public OpLite {
bool CheckShape() const override;
-bool InferShape() const override;
+bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
@@ -26,7 +26,7 @@ bool CompareOp::CheckShape() const {
return true;
}
-bool CompareOp::InferShape() const {
+bool CompareOp::InferShapeImpl() const {
CHECK_OR_FALSE(param_.Out);
// TODO(Superjomn) Enable data sharing.
auto input_dims = param_.X->dims();
......
@@ -30,7 +30,7 @@ class CompareOp : public OpLite {
bool CheckShape() const override;
-bool InferShape() const override;
+bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
@@ -26,7 +26,7 @@ bool ConcatOpLite::CheckShape() const {
return true;
}
-bool ConcatOpLite::InferShape() const {
+bool ConcatOpLite::InferShapeImpl() const {
const std::vector<Tensor *> &inputs = param_.x;
const size_t n = inputs.size();
CHECK_GT_OR_FALSE(n, 0);
......
@@ -30,7 +30,7 @@ class ConcatOpLite : public OpLite {
bool CheckShape() const override;
-bool InferShape() const override;
+bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
@@ -27,7 +27,7 @@ bool ConditionalBlockOpLite::CheckShape() const {
return true;
}
-bool ConditionalBlockOpLite::InferShape() const { return true; }
+bool ConditionalBlockOpLite::InferShapeImpl() const { return true; }
bool ConditionalBlockOpLite::AttachImpl(const cpp::OpDesc &op_desc,
                                        lite::Scope *scope) {
......
@@ -31,7 +31,7 @@ class ConditionalBlockOpLite : public OpLite {
bool CheckShape() const override;
-bool InferShape() const override;
+bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
@@ -80,35 +80,7 @@ void UpdatePaddingAndDilation(std::vector<int>* paddings,
}
}
-bool ConvOpLite::SmartInferShape() {
-  if (!last_input_shapes.empty()) {
-    if (last_input_shapes[0] == param_.x->dims() &&
-        last_input_lods[0] == param_.x->lod()) {
-      param_.output->Resize(last_output_shapes[0]);
-      param_.output->set_lod(last_output_lods[0]);
-      return true;
-    }
-  }
-  this->InferShape();
-  if (!last_input_shapes.empty()) {
-    last_input_shapes.clear();
-    last_input_lods.clear();
-  }
-  last_input_shapes.push_back(param_.x->dims());
-  last_input_lods.push_back(param_.x->lod());
-  if (!last_output_shapes.empty()) {
-    last_output_shapes.clear();
-    last_output_lods.clear();
-  }
-  last_output_shapes.push_back(param_.output->dims());
-  last_output_lods.push_back(param_.output->lod());
-  return true;
-}
-bool ConvOpLite::InferShape() const {
+bool ConvOpLite::InferShapeImpl() const {
const auto in_dims = param_.x->dims();
const auto filter_dims = param_.filter->dims();
......
@@ -34,9 +34,7 @@ class ConvOpLite : public OpLite {
explicit ConvOpLite(const std::string& type) : OpLite(type) {}
bool CheckShape() const override;
+bool InferShapeImpl() const override;
-bool InferShape() const override;
-bool SmartInferShape() override;
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) override {
......
@@ -52,7 +52,7 @@ inline int ConvTransposeOutputSize(int input_size,
return output_size;
}
-bool ConvTransposeOpLite::InferShape() const {
+bool ConvTransposeOpLite::InferShapeImpl() const {
const auto in_dims = param_.x->dims();
const auto filter_dims = param_.filter->dims();
......
@@ -34,7 +34,7 @@ class ConvTransposeOpLite : public OpLite {
bool CheckShape() const override;
-bool InferShape() const override;
+bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
@@ -60,7 +60,7 @@ bool CrfDecodingOpLite::CheckShape() const {
return true;
}
-bool CrfDecodingOpLite::InferShape() const {
+bool CrfDecodingOpLite::InferShapeImpl() const {
auto emission_dims = param_.emission->dims();
if (param_.length == nullptr) {
param_.viterbi_path->Resize({emission_dims[0], 1});
......
@@ -31,7 +31,7 @@ class CrfDecodingOpLite : public OpLite {
bool CheckShape() const override;
-bool InferShape() const override;
+bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
@@ -26,7 +26,7 @@ bool CropOpLite::CheckShape() const {
return true;
}
-bool CropOpLite::InferShape() const {
+bool CropOpLite::InferShapeImpl() const {
// nchw
auto x_dims = param_.X->dims();
lite::DDim output_shape(x_dims);
......
@@ -30,7 +30,7 @@ class CropOpLite : public OpLite {
bool CheckShape() const override;
-bool InferShape() const override;
+bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
@@ -29,7 +29,7 @@ bool DecodeBboxesOpLite::CheckShape() const {
return true;
}
-bool DecodeBboxesOpLite::InferShape() const {
+bool DecodeBboxesOpLite::InferShapeImpl() const {
param_.bbox_data->Resize(param_.loc_data->dims());
return true;
}
......
@@ -29,7 +29,7 @@ class DecodeBboxesOpLite : public OpLite {
bool CheckShape() const override;
-bool InferShape() const override;
+bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
@@ -27,7 +27,7 @@ bool DensityPriorBoxOpLite::CheckShape() const {
return true;
}
-bool DensityPriorBoxOpLite::InferShape() const { return true; }
+bool DensityPriorBoxOpLite::InferShapeImpl() const { return true; }
bool DensityPriorBoxOpLite::AttachImpl(const cpp::OpDesc& opdesc,
                                       lite::Scope* scope) {
......
@@ -30,7 +30,7 @@ class DensityPriorBoxOpLite : public OpLite {
bool CheckShape() const override;
-bool InferShape() const override;
+bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
@@ -32,7 +32,7 @@ bool DistributeFpnProposalsOpLite::CheckShape() const {
return true;
}
-bool DistributeFpnProposalsOpLite::InferShape() const {
+bool DistributeFpnProposalsOpLite::InferShapeImpl() const {
int num_out_rois = param_.max_level - param_.min_level + 1;
for (int i = 0; i < num_out_rois; i++) {
param_.multi_fpn_rois[i]->Resize({-1, 4});
......
@@ -32,7 +32,7 @@ class DistributeFpnProposalsOpLite : public OpLite {
bool CheckShape() const override;
-bool InferShape() const override;
+bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -26,7 +26,7 @@ bool DropoutOp::CheckShape() const { ...@@ -26,7 +26,7 @@ bool DropoutOp::CheckShape() const {
return true; return true;
} }
bool DropoutOp::InferShape() const { bool DropoutOp::InferShapeImpl() const {
const auto x_dims = param_.x->dims(); const auto x_dims = param_.x->dims();
param_.output->Resize(x_dims); param_.output->Resize(x_dims);
if (param_.is_test == false) { if (param_.is_test == false) {
......
...@@ -28,7 +28,7 @@ class DropoutOp : public OpLite { ...@@ -28,7 +28,7 @@ class DropoutOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
// TODO(Superjomn) replace framework::OpDesc with a lite one. // TODO(Superjomn) replace framework::OpDesc with a lite one.
......
...@@ -26,7 +26,7 @@ bool ElementwiseGradOp::CheckShape() const { ...@@ -26,7 +26,7 @@ bool ElementwiseGradOp::CheckShape() const {
return true; return true;
} }
bool ElementwiseGradOp::InferShape() const { bool ElementwiseGradOp::InferShapeImpl() const {
auto x_dim = param_.X->dims(); auto x_dim = param_.X->dims();
auto y_dim = param_.Y->dims(); auto y_dim = param_.Y->dims();
if (param_.XGrad) { if (param_.XGrad) {
......
...@@ -27,7 +27,7 @@ class ElementwiseGradOp : public OpLite { ...@@ -27,7 +27,7 @@ class ElementwiseGradOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override; bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
......
...@@ -26,39 +26,8 @@ bool ElementwiseOp::CheckShape() const { ...@@ -26,39 +26,8 @@ bool ElementwiseOp::CheckShape() const {
CHECK_OR_FALSE(param_.Out); CHECK_OR_FALSE(param_.Out);
return true; return true;
} }
bool ElementwiseOp::SmartInferShape() {
if (!last_input_shapes.empty()) {
if (last_input_shapes[0] == param_.X->dims() &&
last_input_shapes[1] == param_.Y->dims() &&
last_input_lods[0] == param_.X->lod() &&
last_input_lods[1] == param_.Y->lod()) {
param_.Out->Resize(last_output_shapes[0]);
param_.Out->set_lod(last_output_lods[0]);
return true;
}
}
this->InferShape();
if (!last_input_shapes.empty()) {
last_input_shapes.clear();
last_input_lods.clear();
}
last_input_shapes.push_back(param_.X->dims());
last_input_lods.push_back(param_.X->lod());
last_input_shapes.push_back(param_.Y->dims());
last_input_lods.push_back(param_.Y->lod());
if (!last_output_shapes.empty()) { bool ElementwiseOp::InferShapeImpl() const {
last_output_shapes.clear();
last_output_lods.clear();
}
last_output_shapes.push_back(param_.Out->dims());
last_output_lods.push_back(param_.Out->lod());
return true;
}
bool ElementwiseOp::InferShape() const {
auto x_dim = param_.X->dims(); auto x_dim = param_.X->dims();
auto y_dim = param_.Y->dims(); auto y_dim = param_.Y->dims();
if (x_dim == y_dim) { if (x_dim == y_dim) {
...@@ -136,7 +105,7 @@ bool ElementwiseOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) { ...@@ -136,7 +105,7 @@ bool ElementwiseOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
// return true; // return true;
//} //}
// bool ElementwiseGradExplicitOp::InferShape() const { // bool ElementwiseGradExplicitOp::InferShapeImpl() const {
// param_.X_grad->Resize(param_.Out_grad->dims()); // param_.X_grad->Resize(param_.Out_grad->dims());
// if (param_.Y_grad) param_.Y_grad->Resize(param_.Y->dims()); // if (param_.Y_grad) param_.Y_grad->Resize(param_.Y->dims());
// return true; // return true;
......
...@@ -27,8 +27,7 @@ class ElementwiseOp : public OpLite { ...@@ -27,8 +27,7 @@ class ElementwiseOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool SmartInferShape() override;
bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override; bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
...@@ -48,7 +47,7 @@ class ElementwiseOp : public OpLite { ...@@ -48,7 +47,7 @@ class ElementwiseOp : public OpLite {
// bool CheckShape() const override; // bool CheckShape() const override;
// bool InferShape() const override; // bool InferShapeImpl() const override;
// bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override; // bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
......
...@@ -32,7 +32,7 @@ bool ExpandOpLite::CheckShape() const { ...@@ -32,7 +32,7 @@ bool ExpandOpLite::CheckShape() const {
return true; return true;
} }
bool ExpandOpLite::InferShape() const { bool ExpandOpLite::InferShapeImpl() const {
DDim out_dims(param_.X->dims()); DDim out_dims(param_.X->dims());
for (size_t i = 0; i < param_.expand_times.size(); ++i) { for (size_t i = 0; i < param_.expand_times.size(); ++i) {
out_dims[i] *= param_.expand_times[i]; out_dims[i] *= param_.expand_times[i];
......
...@@ -28,7 +28,7 @@ class ExpandOpLite : public OpLite { ...@@ -28,7 +28,7 @@ class ExpandOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -36,7 +36,7 @@ class FakeChannelWiseDequantizeMaxAbsOpLite : public OpLite { ...@@ -36,7 +36,7 @@ class FakeChannelWiseDequantizeMaxAbsOpLite : public OpLite {
bool CheckShape() const override { return true; } bool CheckShape() const override { return true; }
bool InferShape() const override { return true; } bool InferShapeImpl() const override { return true; }
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
auto x = op_desc.Input("X").front(); auto x = op_desc.Input("X").front();
......
...@@ -35,7 +35,7 @@ class FakeDequantizeMaxAbsOpLite : public OpLite { ...@@ -35,7 +35,7 @@ class FakeDequantizeMaxAbsOpLite : public OpLite {
bool CheckShape() const override { return true; } bool CheckShape() const override { return true; }
bool InferShape() const override { return true; } bool InferShapeImpl() const override { return true; }
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
auto x = op_desc.Input("X").front(); auto x = op_desc.Input("X").front();
......
...@@ -36,7 +36,7 @@ class FakeQuantizeDequantizeMovingAvgMaxAbsOpLite : public OpLite { ...@@ -36,7 +36,7 @@ class FakeQuantizeDequantizeMovingAvgMaxAbsOpLite : public OpLite {
bool CheckShape() const override { return true; } bool CheckShape() const override { return true; }
bool InferShape() const override { return true; } bool InferShapeImpl() const override { return true; }
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
auto x = op_desc.Input("X").front(); auto x = op_desc.Input("X").front();
......
...@@ -36,7 +36,7 @@ class FakeQuantizeMovingAvgMaxAbsOpLite : public OpLite { ...@@ -36,7 +36,7 @@ class FakeQuantizeMovingAvgMaxAbsOpLite : public OpLite {
bool CheckShape() const override { return true; } bool CheckShape() const override { return true; }
bool InferShape() const override { return true; } bool InferShapeImpl() const override { return true; }
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
auto x = op_desc.Input("X").front(); auto x = op_desc.Input("X").front();
......
...@@ -36,7 +36,7 @@ class FakeQuantizeRangeMaxAbsOpLite : public OpLite { ...@@ -36,7 +36,7 @@ class FakeQuantizeRangeMaxAbsOpLite : public OpLite {
bool CheckShape() const override { return true; } bool CheckShape() const override { return true; }
bool InferShape() const override { return true; } bool InferShapeImpl() const override { return true; }
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
auto x = op_desc.Input("X").front(); auto x = op_desc.Input("X").front();
......
...@@ -48,34 +48,7 @@ bool FcOpLite::CheckShape() const { ...@@ -48,34 +48,7 @@ bool FcOpLite::CheckShape() const {
return true; return true;
} }
bool FcOpLite::SmartInferShape() { bool FcOpLite::InferShapeImpl() const {
if (!last_input_shapes.empty() && !last_output_shapes.empty()) {
if (last_input_shapes[0] == param_.input->dims() &&
last_input_lods[0] == param_.input->lod()) {
param_.output->Resize(last_output_shapes[0]);
param_.output->set_lod(last_output_lods[0]);
return true;
}
}
this->InferShape();
if (!last_input_shapes.empty()) {
last_input_shapes.clear();
last_input_lods.clear();
}
last_input_shapes.push_back(param_.input->dims());
last_input_lods.push_back(param_.input->lod());
if (!last_output_shapes.empty()) {
last_output_shapes.clear();
last_output_lods.clear();
}
last_output_shapes.push_back(param_.output->dims());
last_output_lods.push_back(param_.output->lod());
return true;
}
bool FcOpLite::InferShape() const {
const auto& input_dims = param_.input->dims(); const auto& input_dims = param_.input->dims();
const auto& w_dims = param_.w->dims(); const auto& w_dims = param_.w->dims();
int in_num_col_dims = param_.in_num_col_dims; int in_num_col_dims = param_.in_num_col_dims;
......
...@@ -35,8 +35,7 @@ class FcOpLite : public OpLite { ...@@ -35,8 +35,7 @@ class FcOpLite : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool SmartInferShape() override;
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override;
......
...@@ -29,7 +29,7 @@ class FeedOp : public OpLite { ...@@ -29,7 +29,7 @@ class FeedOp : public OpLite {
return true; return true;
} }
bool InferShape() const override { return true; } bool InferShapeImpl() const override { return true; }
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
......
...@@ -29,7 +29,7 @@ class FetchOp : public OpLite { ...@@ -29,7 +29,7 @@ class FetchOp : public OpLite {
return true; return true;
} }
bool InferShape() const override { return true; } bool InferShapeImpl() const override { return true; }
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); } void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
protected: protected:
......
...@@ -28,7 +28,7 @@ bool FillConstantBatchSizeLikeOp::CheckShape() const { ...@@ -28,7 +28,7 @@ bool FillConstantBatchSizeLikeOp::CheckShape() const {
return true; return true;
} }
bool FillConstantBatchSizeLikeOp::InferShape() const { bool FillConstantBatchSizeLikeOp::InferShapeImpl() const {
std::vector<int64_t> output_dim{param_.shape.begin(), param_.shape.end()}; std::vector<int64_t> output_dim{param_.shape.begin(), param_.shape.end()};
if (param_.input_dim_idx == 0 && !param_.input->lod().empty()) { if (param_.input_dim_idx == 0 && !param_.input->lod().empty()) {
output_dim[param_.output_dim_idx] = param_.input->lod().back().size() - 1; output_dim[param_.output_dim_idx] = param_.input->lod().back().size() - 1;
......
...@@ -32,7 +32,7 @@ class FillConstantBatchSizeLikeOp : public OpLite { ...@@ -32,7 +32,7 @@ class FillConstantBatchSizeLikeOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -24,7 +24,7 @@ bool FillConstantOp::CheckShape() const { ...@@ -24,7 +24,7 @@ bool FillConstantOp::CheckShape() const {
return true; return true;
} }
bool FillConstantOp::InferShape() const { bool FillConstantOp::InferShapeImpl() const {
std::vector<int64_t> out_shape; std::vector<int64_t> out_shape;
auto shape_tensor = param_.shape_tensor; auto shape_tensor = param_.shape_tensor;
auto shape_tensor_list = param_.shape_tensor_list; auto shape_tensor_list = param_.shape_tensor_list;
......
...@@ -31,7 +31,7 @@ class FillConstantOp : public OpLite { ...@@ -31,7 +31,7 @@ class FillConstantOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -25,7 +25,7 @@ bool FlattenOp::CheckShape() const { ...@@ -25,7 +25,7 @@ bool FlattenOp::CheckShape() const {
return true; return true;
} }
bool FlattenOp::InferShape() const { bool FlattenOp::InferShapeImpl() const {
auto x_dims = param_.x->dims(); auto x_dims = param_.x->dims();
auto out_lod = param_.output->mutable_lod(); auto out_lod = param_.output->mutable_lod();
...@@ -71,8 +71,8 @@ bool Flatten2Op::CheckShape() const { ...@@ -71,8 +71,8 @@ bool Flatten2Op::CheckShape() const {
return true; return true;
} }
bool Flatten2Op::InferShape() const { bool Flatten2Op::InferShapeImpl() const {
FlattenOp::InferShape(); FlattenOp::InferShapeImpl();
auto x_dims = param_.x->dims(); auto x_dims = param_.x->dims();
std::vector<DDim::value_type> xshape_dims(x_dims.size() + 1, 0); std::vector<DDim::value_type> xshape_dims(x_dims.size() + 1, 0);
for (size_t i = 0; i < x_dims.size(); i++) { for (size_t i = 0; i < x_dims.size(); i++) {
......
...@@ -30,7 +30,7 @@ class FlattenOp : public OpLite { ...@@ -30,7 +30,7 @@ class FlattenOp : public OpLite {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
...@@ -49,7 +49,7 @@ class Flatten2Op : public FlattenOp { ...@@ -49,7 +49,7 @@ class Flatten2Op : public FlattenOp {
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
......
...@@ -27,7 +27,7 @@ bool FusionElementwiseActivationOp::CheckShape() const { ...@@ -27,7 +27,7 @@ bool FusionElementwiseActivationOp::CheckShape() const {
return true; return true;
} }
bool FusionElementwiseActivationOp::InferShape() const { bool FusionElementwiseActivationOp::InferShapeImpl() const {
CHECK_OR_FALSE(param_.X->dims().size() >= param_.Y->dims().size()); CHECK_OR_FALSE(param_.X->dims().size() >= param_.Y->dims().size());
param_.Out->Resize(param_.X->dims()); param_.Out->Resize(param_.X->dims());
return true; return true;
...@@ -59,7 +59,7 @@ bool FusionElementwiseActivationOp::AttachImpl(const cpp::OpDesc& opdesc, ...@@ -59,7 +59,7 @@ bool FusionElementwiseActivationOp::AttachImpl(const cpp::OpDesc& opdesc,
// return true; // return true;
// } // }
// bool FusionElementwiseActivationGradExplicitOp::InferShape() const { // bool FusionElementwiseActivationGradExplicitOp::InferShapeImpl() const {
// param_.X_grad->Resize(param_.Out_grad->dims()); // param_.X_grad->Resize(param_.Out_grad->dims());
// param_.Y_grad->Resize(param_.Y->dims()); // param_.Y_grad->Resize(param_.Y->dims());
// return true; // return true;
......
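Every hunk above makes the same mechanical change: op classes stop overriding the public `InferShape()` (and drop their per-op `SmartInferShape()` caching, as in the `ElementwiseOp` and `FcOpLite` hunks) and instead override `InferShapeImpl()`, while `CheckShape()` and `AttachImpl()` are left untouched. Below is a minimal sketch (not part of the commit) of what an affected op header looks like after the rename; `FooOpLite` and `FooParam` are assumed names used only for illustration, and other boilerplate an `OpLite` subclass normally carries is omitted.

```cpp
#include "lite/core/op_lite.h"

namespace paddle {
namespace lite {
namespace operators {

// Hypothetical parameter struct, standing in for the structs the real ops
// take from lite/operators/op_params.h.
struct FooParam {
  lite::Tensor *x{nullptr};
  lite::Tensor *output{nullptr};
};

class FooOpLite : public OpLite {
 public:
  bool CheckShape() const override { return true; }

  // Formerly `bool InferShape() const override;` -- same body, new name.
  bool InferShapeImpl() const override {
    // Derive the output shape from the input shape, as the renamed
    // per-op implementations in the hunks above do.
    param_.output->Resize(param_.x->dims());
    return true;
  }

  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;

 private:
  FooParam param_;  // hypothetical; real ops keep their existing param_ type
};

}  // namespace operators
}  // namespace lite
}  // namespace paddle
```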
(The remaining file diffs in this commit are collapsed.)