提交 9949681a 编写于 作者: H hjchen2

Aggregate activation operations, add log and lod_reset op

上级 6e6d5dde
......@@ -73,6 +73,8 @@ const char *G_OP_TYPE_SHAPE = "shape";
const char *G_OP_TYPE_SUM = "sum";
const char *G_OP_TYPE_TOP_K = "top_k";
const char *G_OP_TYPE_CAST = "cast";
const char *G_OP_TYPE_LOG = "log";
const char *G_OP_TYPE_LOD_RESET = "lod_reset";
const char *G_OP_TYPE_QUANTIZE = "quantize";
const char *G_OP_TYPE_DEQUANTIZE = "dequantize";
......@@ -171,5 +173,7 @@ std::unordered_map<
{G_OP_TYPE_SEQUENCE_EXPAND, {{"X", "Y"}, {"Out"}}},
{G_OP_TYPE_SEQUENCE_POOL, {{"X"}, {"Out"}}},
{G_OP_TYPE_SEQUENCE_SOFTMAX, {{"X"}, {"Out"}}},
{G_OP_TYPE_NORM, {{"X"}, {"Out", "Norm"}}}};
{G_OP_TYPE_NORM, {{"X"}, {"Out", "Norm"}}},
{G_OP_TYPE_LOG, {{"X"}, {"Out"}}},
{G_OP_TYPE_LOD_RESET, {{"X", "Y"}, {"Out"}}}};
} // namespace paddle_mobile
......@@ -100,6 +100,7 @@ enum ActivationType {
LEAKY_RELU = 4,
TANH = 5,
SIGMOID = 6,
LOG = 7,
};
enum PoolingType {
......@@ -155,6 +156,8 @@ extern const char *G_OP_TYPE_PRELU;
extern const char *G_OP_TYPE_SUM;
extern const char *G_OP_TYPE_TOP_K;
extern const char *G_OP_TYPE_CAST;
extern const char *G_OP_TYPE_LOG;
extern const char *G_OP_TYPE_LOD_RESET;
extern const char *G_OP_TYPE_QUANTIZE;
extern const char *G_OP_TYPE_DEQUANTIZE;
......
......@@ -270,3 +270,6 @@ LOAD_OP1(sequence_expand, CPU);
#ifdef SEQUENCE_POOL_OP
LOAD_OP1(sequence_pool, CPU);
#endif
#ifdef LOG_OP
LOAD_OP1(log, CPU);
#endif
......@@ -36,13 +36,12 @@ limitations under the License. */
#include "framework/cl/cl_helper.h"
#include "framework/cl/cl_scope.h"
#endif
namespace paddle_mobile {
namespace framework {
using std::string;
using std::vector;
template <typename T>
static T *GetVarValue(const string &key, const VariableNameMap &var_map,
static T *GetVarValue(const std::string &key, const VariableNameMap &var_map,
const Scope &scope) {
auto var_vec = var_map.at(key);
if (!var_vec.empty()) {
......@@ -56,44 +55,29 @@ static T *GetVarValue(const string &key, const VariableNameMap &var_map,
template <typename Dtype>
class OperatorBase {
public:
/*
* @b op 基类的实例化方法, op 获取到了 输入、参数以及提前分配好的输出 tensor
* */
OperatorBase(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs,
std::shared_ptr<Scope> scope);
virtual ~OperatorBase() {}
void Run();
std::vector<string> GetOutKeys() const;
std::vector<string> GetInputKeys() const;
virtual void RunImpl() = 0;
virtual void Init() = 0;
/*
* @b op 运算所需的输入, 如上一层的输出结果、卷积核
* */
virtual void InferShape() const = 0;
virtual void Run();
virtual void RunImpl() = 0;
std::vector<std::string> GetOutKeys() const;
std::vector<std::string> GetInputKeys() const;
const VariableNameMap &Inputs() const { return inputs_; }
/*
* @b op 的输出, 内存会提前被分配好, 运算结果会被存到分配好的内存内
* */
const VariableNameMap &Outputs() const { return outputs_; }
/*
* @b op 类型
* */
const std::string &Type() const { return type_; }
/*
* @b op 运算所需要用到的参数: 如 conv 运算所需要用到的 stride
* */
const AttributeMap &Attrs() const { return attrs_; }
void ClearVariables(const std::vector<std::string> &var_names) const {
if (this->scope_) {
this->scope_->EraseVars(var_names);
}
}
/*
* @b 根据输入形状和参数计算出输出形状
* */
virtual void InferShape() const = 0;
protected:
std::shared_ptr<Scope> scope_;
......@@ -106,9 +90,6 @@ class OperatorBase {
void CheckAllInputOutputSet() const;
};
/*
* @b 这个类为所有带有运算的 op 的父类, 这个 op 继承与 OperatorBase
* */
template <typename Dtype, typename ParamType, typename KernelType>
class OperatorWithKernel : public OperatorBase<Dtype> {
public:
......@@ -136,9 +117,6 @@ class OperatorWithKernel : public OperatorBase<Dtype> {
ParamType param_;
};
/*
* @b 所有kernel的父类
* */
template <typename Dtype, typename P>
class OpKernelBase {
public:
......@@ -150,11 +128,6 @@ class OpKernelBase {
}
#endif
/*
* @b 所有kernel 需实现 Compute 方法
* @p para 这个参数为 kernel 运算时所需要用到参数组成的一个结构体,
* 所有结构体存在与: paddle-mobile/src/operators/op_param.h
* */
#ifdef PADDLE_McOBILE_MALI_GPU
OpKernelBase() { acl_op_ = nullptr; }
void *GetAclOp() const { return acl_op_; }
......@@ -177,6 +150,23 @@ class OpKernelBase {
#endif
};
#define DECLARE_OPERATOR(OpName, OpParam, OpKernel) \
template <typename DeviceType, typename T> \
class OpName##Op : public framework::OperatorWithKernel< \
DeviceType, OpParam<DeviceType>, \
operators::OpKernel<DeviceType, T>> { \
public: \
OpName##Op(const std::string &type, const VariableNameMap &inputs, \
const VariableNameMap &outputs, \
const framework::AttributeMap &attrs, \
std::shared_ptr<framework::Scope> scope) \
: framework::OperatorWithKernel<DeviceType, OpParam<DeviceType>, \
operators::OpKernel<DeviceType, T>>( \
type, inputs, outputs, attrs, scope) {} \
\
void InferShape() const override; \
};
#define DEFINE_OP_CONSTRUCTOR(cls, parent_cls) \
cls(const std::string &type, const ::paddle_mobile::VariableNameMap &inputs, \
const ::paddle_mobile::VariableNameMap &outputs, \
......@@ -202,7 +192,6 @@ class FusionOpMatcher {
virtual std::vector<std::pair<int, std::string>> NeedCheck() { return {}; }
// virtual bool Fusion();
protected:
Node node_;
std::string type_;
......
......@@ -12,29 +12,40 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef RELU_OP
#include "operators/relu_op.h"
#include "operators/activation_op.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void ReluOp<Dtype, T>::InferShape() const {
auto input_dims = this->param_.InputX()->dims();
this->param_.Out()->Resize(input_dims);
}
#define DEFINE_INFERSHAPE(OpName) \
template <typename Dtype, typename T> \
void OpName##Op<Dtype, T>::InferShape() const { \
const auto &input_dims = this->param_.InputX()->dims(); \
this->param_.Out()->Resize(input_dims); \
}
#ifdef RELU_OP
DEFINE_INFERSHAPE(Relu);
DEFINE_INFERSHAPE(Relu6);
#endif // RELU_OP
template <typename Dtype, typename T>
void Relu6Op<Dtype, T>::InferShape() const {
auto input_dims = this->param_.InputX()->dims();
this->param_.Out()->Resize(input_dims);
}
#ifdef SIGMOID_OP
DEFINE_INFERSHAPE(Sigmoid);
#endif // SIGMOID_OP
#ifdef TANH_OP
DEFINE_INFERSHAPE(Tanh);
#endif // TANH_OP
#ifdef LOG_OP
DEFINE_INFERSHAPE(Log);
#endif // LOG_OP
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef RELU_OP
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(relu, ops::ReluOp);
REGISTER_OPERATOR_CPU(relu6, ops::Relu6Op);
......@@ -47,5 +58,23 @@ REGISTER_OPERATOR_MALI_GPU(relu, ops::ReluOp);
#ifdef PADDLE_MOBILE_CL
REGISTER_OPERATOR_CL(relu, ops::ReluOp);
#endif
#endif // RELU_OP
#ifdef SIGMOID_OP
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(sigmoid, ops::SigmoidOp);
#endif
#endif // SIGMOID_OP
#ifdef TANH_OP
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(tanh, ops::TanhOp);
#endif
#ifdef PADDLE_MOBILE_FPGA
REGISTER_OPERATOR_FPGA(tanh, ops::TanhOp);
#endif
#endif // TANH_OP
#ifdef LOG_OP
REGISTER_OPERATOR_CPU(log, ops::LogOp);
#endif // LOG_OP
......@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef TANH_OP
#pragma once
#include <string>
......@@ -24,21 +22,22 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class TanhOp : public framework::OperatorWithKernel<
DeviceType, TanhParam<DeviceType>,
operators::TanhKernel<DeviceType, T>> {
public:
TanhOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType, TanhParam<DeviceType>,
operators::TanhKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
void InferShape() const override;
};
#ifdef RELU_OP
DECLARE_OPERATOR(Relu, ReluParam, ReluKernel);
DECLARE_OPERATOR(Relu6, ReluParam, Relu6Kernel);
#endif
} // namespace operators
} // namespace paddle_mobile
#ifdef SIGMOID_OP
DECLARE_OPERATOR(Sigmoid, SigmoidParam, SigmoidKernel);
#endif
#ifdef TANH_OP
DECLARE_OPERATOR(Tanh, TanhParam, TanhKernel);
#endif
#ifdef LOG_OP
DECLARE_OPERATOR(Log, ReluParam, LogKernel);
#endif
} // namespace operators
} // namespace paddle_mobile
......@@ -20,26 +20,30 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
#define DECLARE_KERNEL(KernelClass, KernelParam) \
template <typename DeviceType, typename T> \
class KernelClass \
: public framework::OpKernelBase<DeviceType, KernelParam<DeviceType>> { \
public: \
bool Init(KernelParam<DeviceType> *param); \
void Compute(const KernelParam<DeviceType> &param); \
#define DECLARE_KERNEL(OpName, Param) \
template <typename DeviceType, typename T> \
class OpName##Kernel \
: public framework::OpKernelBase<DeviceType, Param<DeviceType>> { \
public: \
bool Init(Param<DeviceType> *param); \
void Compute(const Param<DeviceType> &param); \
};
#ifdef RELU_OP
DECLARE_KERNEL(ReluKernel, ReluParam);
DECLARE_KERNEL(Relu6Kernel, ReluParam);
DECLARE_KERNEL(Relu, ReluParam);
DECLARE_KERNEL(Relu6, ReluParam);
#endif
#ifdef SIGMOID_OP
DECLARE_KERNEL(SigmoidKernel, SigmoidParam);
DECLARE_KERNEL(Sigmoid, SigmoidParam);
#endif
#ifdef TANH_OP
DECLARE_KERNEL(TanhKernel, TanhParam);
DECLARE_KERNEL(Tanh, TanhParam);
#endif
#ifdef LOG_OP
DECLARE_KERNEL(Log, ReluParam);
#endif
} // namespace operators
......
......@@ -105,7 +105,7 @@ void SigmoidKernel<CPU, float>::Compute(const SigmoidParam<CPU> &param) {
#ifdef TANH_OP
template <>
void TanhKernel<CPU, float>::Init(TanhParam<CPU> *param) {
bool TanhKernel<CPU, float>::Init(TanhParam<CPU> *param) {
return true;
}
......@@ -117,5 +117,19 @@ void TanhKernel<CPU, float>::Compute(const TanhParam<CPU> &param) {
}
#endif
#ifdef LOG_OP
template <>
bool LogKernel<CPU, float>::Init(ReluParam<CPU> *param) {
return true;
}
template <>
void LogKernel<CPU, float>::Compute(const ReluParam<CPU> &param) {
const Tensor *input = param.InputX();
Tensor *output = param.Out();
ActivationCompute<float, LOG>()(input, output);
}
#endif
} // namespace operators
} // namespace paddle_mobile
......@@ -12,34 +12,51 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SIGMOID_OP
#ifdef LOD_RESET_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/activation_kernel.h"
#include "operators/op_param.h"
#include "operators/kernel/kernels.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class SigmoidOp : public framework::OperatorWithKernel<
DeviceType, SigmoidParam<DeviceType>,
operators::SigmoidKernel<DeviceType, T>> {
public:
SigmoidOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType, SigmoidParam<DeviceType>,
operators::SigmoidKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
void InferShape() const override;
};
template <>
bool LodResetKernel<CPU, float>::Init(LodResetParam<CPU> *param) {
return true;
}
template <>
void LodResetKernel<CPU, float>::Compute(const LodResetParam<CPU> &param) {
const auto *input = param.input_x_;
const auto *lod_t = param.input_y_;
auto *output = param.output_;
output->ShareDataWith(*input);
std::vector<int> level0;
if (lod_t) {
if (lod_t->lod().size() > 0) {
output->set_lod(lod_t->lod());
return; // early return, since lod already set
} else {
auto *lod = lod_t->data<int>();
level0 = std::vector<int>(lod, lod + lod_t->numel());
}
} else {
level0 = param.target_lod_;
}
// cast level0 to size_t
std::vector<size_t> ulevel0(level0.size(), 0);
for (int i = 0; i < level0.size(); ++i) {
ulevel0[i] = level0[i];
}
framework::LoD target_lod;
target_lod.push_back(std::move(ulevel0));
output->set_lod(target_lod);
}
} // namespace operators
} // namespace paddle_mobile
#endif
#endif // LOD_RESET_OP
......@@ -22,7 +22,7 @@ namespace operators {
#define DECLARE_KERNEL(KernelClass, KernelParam) \
template <typename DeviceType, typename T> \
class KernelClass \
class KernelClass##Kernel \
: public framework::OpKernelBase<DeviceType, KernelParam<DeviceType>> { \
public: \
bool Init(KernelParam<DeviceType> *param); \
......@@ -30,12 +30,16 @@ namespace operators {
};
#ifdef TOP_K_OP
DECLARE_KERNEL(TopKKernel, TopKParam)
DECLARE_KERNEL(TopK, TopKParam)
#endif // TOP_K_OP
#ifdef CAST_OP
DECLARE_KERNEL(CastKernel, CastParam)
DECLARE_KERNEL(Cast, CastParam)
#endif // CAST_OP
#ifdef LOD_RESET_OP
DECLARE_KERNEL(LodReset, LodResetParam)
#endif // LOD_RESET_OP
} // namespace operators
} // namespace paddle_mobile
......@@ -12,16 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SIGMOID_OP
#ifdef LOD_RESET_OP
#include "operators/sigmoid_op.h"
#include "operators/lod_reset_op.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
void SigmoidOp<DeviceType, T>::InferShape() const {
this->param_.Out()->Resize(this->param_.InputX()->dims());
template <typename Dtype, typename T>
void LodResetOp<Dtype, T>::InferShape() const {
const auto &input_dims = this->param_.input_->dims();
this->param_.output_->Resize(input_dims);
}
} // namespace operators
......@@ -29,7 +30,7 @@ void SigmoidOp<DeviceType, T>::InferShape() const {
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(sigmoid, ops::SigmoidOp);
REGISTER_OPERATOR_CPU(lod_reset, ops::LodResetOp);
#endif
#endif
#endif // LOD_RESET_OP
......@@ -12,27 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef TANH_OP
#ifdef LOD_RESET_OP
#include "operators/tanh_op.h"
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/kernels.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
void TanhOp<DeviceType, T>::InferShape() const {
this->param_.Out()->Resize(this->param_.InputX()->dims());
}
DECLARE_OPERATOR(LodReset, LodResetParam, LodResetKernel);
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(tanh, ops::TanhOp);
#endif
#ifdef PADDLE_MOBILE_FPGA
REGISTER_OPERATOR_FPGA(tanh, ops::TanhOp);
#endif
#endif
#endif // LOD_RESET_OP
......@@ -86,6 +86,11 @@ inline float32x4_t vActiveq_f32<TANH>(const float32x4_t &x) {
__out = vmulq_n_f32(__out, 2.f);
return vsubq_f32(__out, __one);
}
template <>
inline float32x4_t vActiveq_f32<LOG>(const float32x4_t &x) {
return log_ps(x);
}
#endif
template <ActivationType Act = IDENTITY>
......@@ -119,6 +124,11 @@ inline float Active<TANH>(const float &x) {
return 2.f / (1.f + exp(-2.f * x)) - 1.f;
}
template <>
inline float Active<LOG>(const float &x) {
return log(x);
}
} // namespace math
} // namespace operators
} // namespace paddle_mobile
......@@ -2829,5 +2829,32 @@ class SequencePoolParam : public OpParam {
};
#endif // SEQUENCE_EXPAND_OP
#ifdef LOD_RESET_OP
template <typename Dtype>
class LodResetParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
LodResetParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
input_x_ = InputXFrom<GType>(inputs, scope);
output_ = OutFrom<GType>(outputs, scope);
input_y_ = nullptr;
if (inputs.count("Y")) {
input_y_ = InputYFrom<GType>(inputs, scope);
} else {
target_lod_ = OpParam::GetAttr<vector<int>>("target_lod", attrs);
}
}
public:
GType *input_x_;
GType *input_y_;
GType *output_;
std::vector<int> target_lod_;
};
#endif // LOD_RESET_OP
} // namespace operators
} // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef RELU_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/activation_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class ReluOp : public framework::OperatorWithKernel<
DeviceType, ReluParam<DeviceType>,
operators::ReluKernel<DeviceType, T>> {
public:
ReluOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType, ReluParam<DeviceType>,
operators::ReluKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
void InferShape() const override;
};
template <typename DeviceType, typename T>
class Relu6Op : public framework::OperatorWithKernel<
DeviceType, ReluParam<DeviceType>,
operators::Relu6Kernel<DeviceType, T>> {
public:
Relu6Op(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType, ReluParam<DeviceType>,
operators::Relu6Kernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
void InferShape() const override;
};
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -238,6 +238,12 @@ if (NOT FOUND_MATCH)
ADD_EXECUTABLE(test-relu6-op operators/test_relu6_op.cpp test_helper.h test_include.h)
target_link_libraries(test-relu6-op paddle-mobile)
ADD_EXECUTABLE(test-tanh-op operators/test_tanh_op.cpp test_helper.h test_include.h)
target_link_libraries(test-tanh-op paddle-mobile)
ADD_EXECUTABLE(test-log-op operators/test_log_op.cpp test_helper.h test_include.h)
target_link_libraries(test-log-op paddle-mobile)
ADD_EXECUTABLE(test-topk-op operators/test_topk_op.cpp test_helper.h test_include.h)
target_link_libraries(test-topk-op paddle-mobile)
......
......@@ -20,12 +20,11 @@ limitations under the License. */
#include "common/log.h"
#include "framework/executor.h"
#include "framework/op_registry.h"
#include "operators/activation_op.h"
#include "operators/conv_op.h"
#include "operators/elementwise_add_op.h"
#include "operators/pool_op.h"
#include "operators/relu_op.h"
#include "operators/reshape_op.h"
#include "operators/sigmoid_op.h"
#include "operators/softmax_op.h"
#include "operators/transpose_op.h"
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cmath>
#include <iostream>
#include "../test_include.h"
#include "operators/activation_op.h"
namespace paddle_mobile {
void Log(const framework::Tensor *X, framework::Tensor *Y) {
const float *x = X->data<float>();
float *y = Y->mutable_data<float>();
for (int i = 0; i < X->numel(); ++i) {
y[i] = log(x[i]);
}
}
int TestLogOp(const std::vector<int> input_shape) {
framework::DDim dims = framework::make_ddim(input_shape);
VariableNameMap inputs;
VariableNameMap outputs;
auto scope = std::make_shared<framework::Scope>();
inputs["X"] = std::vector<std::string>({"input"});
outputs["Out"] = std::vector<std::string>({"output"});
auto input_var = scope.get()->Var("input");
auto input = input_var->template GetMutable<framework::LoDTensor>();
SetupTensor<float>(input, dims, 0.0001, 100.0);
auto output_var = scope.get()->Var("output");
framework::AttributeMap attrs;
auto *op =
new operators::LogOp<CPU, float>("log", inputs, outputs, attrs, scope);
op->InferShape();
op->Init();
op->Run();
auto output = output_var->template Get<framework::LoDTensor>();
framework::Tensor output_cmp;
float *output_cmp_data = output_cmp.mutable_data<float>(output->dims());
Log(input, &output_cmp);
const float *output_data = output->data<float>();
for (int i = 0; i < output->numel(); ++i) {
float gap = output_data[i] - output_cmp_data[i];
if (std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
<< ", output_cmp_data[" << i
<< "] = " << output_cmp_data[i];
delete op;
exit(1);
}
}
delete op;
return 0;
}
} // namespace paddle_mobile
int main() {
paddle_mobile::TestLogOp({1, 1, 2, 3});
paddle_mobile::TestLogOp({1, 3, 11, 22});
paddle_mobile::TestLogOp({1, 32, 112, 112});
std::cout << "test log op pass." << std::endl;
return 0;
}
......@@ -15,7 +15,7 @@ limitations under the License. */
#include <cmath>
#include <iostream>
#include "../test_include.h"
#include "operators/relu_op.h"
#include "operators/activation_op.h"
namespace paddle_mobile {
......
......@@ -15,7 +15,7 @@ limitations under the License. */
#include <cmath>
#include <iostream>
#include "../test_include.h"
#include "operators/relu_op.h"
#include "operators/activation_op.h"
namespace paddle_mobile {
......
......@@ -12,15 +12,70 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "../test_helper.h"
#include <cmath>
#include <iostream>
#include "../test_include.h"
#include "operators/activation_op.h"
int main() {
paddle_mobile::framework::Tensor input;
paddle_mobile::framework::Tensor output;
SetupTensor<float>(&input, {1, 4, 60, 60}, static_cast<float>(0),
static_cast<float>(1));
namespace paddle_mobile {
void Sigmoid(const framework::Tensor *X, framework::Tensor *Y) {
const float *x = X->data<float>();
float *y = Y->mutable_data<float>();
for (int i = 0; i < X->numel(); ++i) {
y[i] = 1.f / (1.f + exp(-x[i]));
}
}
int TestSigmoidOp(const std::vector<int> input_shape) {
framework::DDim dims = framework::make_ddim(input_shape);
VariableNameMap inputs;
VariableNameMap outputs;
auto scope = std::make_shared<framework::Scope>();
inputs["X"] = std::vector<std::string>({"input"});
outputs["Out"] = std::vector<std::string>({"output"});
auto input_var = scope.get()->Var("input");
auto input = input_var->template GetMutable<framework::LoDTensor>();
SetupTensor<float>(input, dims, -100.0, 100.0);
auto output_var = scope.get()->Var("output");
framework::AttributeMap attrs;
auto *op = new operators::SigmoidOp<CPU, float>("sigmoid", inputs, outputs,
attrs, scope);
op->InferShape();
op->Init();
op->Run();
auto out_ddim = paddle_mobile::framework::make_ddim({1, 4, 60, 60});
output.Resize(out_ddim);
auto output = output_var->template Get<framework::LoDTensor>();
framework::Tensor output_cmp;
float *output_cmp_data = output_cmp.mutable_data<float>(output->dims());
Sigmoid(input, &output_cmp);
const float *output_data = output->data<float>();
for (int i = 0; i < output->numel(); ++i) {
float gap = output_data[i] - output_cmp_data[i];
if (std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
<< ", output_cmp_data[" << i
<< "] = " << output_cmp_data[i];
delete op;
exit(1);
}
}
delete op;
return 0;
}
} // namespace paddle_mobile
int main() {
paddle_mobile::TestSigmoidOp({1, 1, 2, 3});
paddle_mobile::TestSigmoidOp({1, 3, 11, 22});
paddle_mobile::TestSigmoidOp({1, 32, 112, 112});
std::cout << "test sigmoid op pass." << std::endl;
return 0;
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cmath>
#include <iostream>
#include "../test_include.h"
#include "operators/activation_op.h"
namespace paddle_mobile {
void Tanh(const framework::Tensor *X, framework::Tensor *Y) {
const float *x = X->data<float>();
float *y = Y->mutable_data<float>();
for (int i = 0; i < X->numel(); ++i) {
y[i] = 2.f / (1.f + exp(-2.f * x[i])) - 1.f;
}
}
int TestTanhOp(const std::vector<int> input_shape) {
framework::DDim dims = framework::make_ddim(input_shape);
VariableNameMap inputs;
VariableNameMap outputs;
auto scope = std::make_shared<framework::Scope>();
inputs["X"] = std::vector<std::string>({"input"});
outputs["Out"] = std::vector<std::string>({"output"});
auto input_var = scope.get()->Var("input");
auto input = input_var->template GetMutable<framework::LoDTensor>();
SetupTensor<float>(input, dims, -100.0, 100.0);
auto output_var = scope.get()->Var("output");
framework::AttributeMap attrs;
auto *op =
new operators::TanhOp<CPU, float>("tanh", inputs, outputs, attrs, scope);
op->InferShape();
op->Init();
op->Run();
auto output = output_var->template Get<framework::LoDTensor>();
framework::Tensor output_cmp;
float *output_cmp_data = output_cmp.mutable_data<float>(output->dims());
Tanh(input, &output_cmp);
const float *output_data = output->data<float>();
for (int i = 0; i < output->numel(); ++i) {
float gap = output_data[i] - output_cmp_data[i];
if (std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
<< ", output_cmp_data[" << i
<< "] = " << output_cmp_data[i];
delete op;
exit(1);
}
}
delete op;
return 0;
}
} // namespace paddle_mobile
int main() {
paddle_mobile::TestTanhOp({1, 1, 2, 3});
paddle_mobile::TestTanhOp({1, 3, 11, 22});
paddle_mobile::TestTanhOp({1, 32, 112, 112});
std::cout << "test sigmoid op pass." << std::endl;
return 0;
}
......@@ -276,6 +276,8 @@ if(NOT FOUND_MATCH)
set(SEQUENCE_EXPAND_OP ON)
set(SEQUENCE_POOL_OP ON)
set(SEQUENCE_SOFTMAX_OP ON)
set(LOG_OP ON)
set(TANH_OP ON)
endif()
# option(BATCHNORM_OP "" ON)
......@@ -512,6 +514,9 @@ endif()
if (SEQUENCE_SOFTMAX_OP)
add_definitions(-DSEQUENCE_SOFTMAX_OP)
endif()
if (LOG_OP)
add_definitions(-DLOG_OP)
endif()
if (TANH_OP)
add_definitions(-DTANH_OP)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册