提交 d9400243 编写于 作者: Q qiaolongfei

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into public_to_protected

...@@ -34,9 +34,6 @@ RUN apt-get update && \ ...@@ -34,9 +34,6 @@ RUN apt-get update && \
net-tools && \ net-tools && \
apt-get clean -y apt-get clean -y
# paddle is using numpy.flip, which is introduced since 1.12.0
RUN pip --no-cache-dir install 'numpy>=1.12.0'
# Install Go and glide # Install Go and glide
RUN wget -qO- https://storage.googleapis.com/golang/go1.8.1.linux-amd64.tar.gz | \ RUN wget -qO- https://storage.googleapis.com/golang/go1.8.1.linux-amd64.tar.gz | \
tar -xz -C /usr/local && \ tar -xz -C /usr/local && \
...@@ -58,13 +55,16 @@ RUN localedef -i en_US -f UTF-8 en_US.UTF-8 ...@@ -58,13 +55,16 @@ RUN localedef -i en_US -f UTF-8 en_US.UTF-8
# FIXME: due to temporary ipykernel dependency issue, specify ipykernel jupyter # FIXME: due to temporary ipykernel dependency issue, specify ipykernel jupyter
# version util jupyter fixes this issue. # version util jupyter fixes this issue.
RUN pip install --upgrade pip && \ RUN pip install --upgrade pip && \
pip install -U 'protobuf==3.1.0' && \ pip install -U wheel && \
pip install -U wheel pillow BeautifulSoup && \
pip install -U docopt PyYAML sphinx && \ pip install -U docopt PyYAML sphinx && \
pip install -U sphinx-rtd-theme==0.1.9 recommonmark && \ pip install -U sphinx-rtd-theme==0.1.9 recommonmark
pip install pre-commit 'requests==2.9.2' 'ipython==5.3.0' && \
RUN pip install pre-commit 'ipython==5.3.0' && \
pip install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \ pip install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \
pip install opencv-python rarfile 'scipy>=0.19.0' 'nltk>=3.2.2' pip install opencv-python
COPY ./python/requirements.txt /root/
RUN pip install -r /root/requirements.txt
# To fix https://github.com/PaddlePaddle/Paddle/issues/1954, we use # To fix https://github.com/PaddlePaddle/Paddle/issues/1954, we use
# the solution in https://urllib3.readthedocs.io/en/latest/user-guide.html#ssl-py2 # the solution in https://urllib3.readthedocs.io/en/latest/user-guide.html#ssl-py2
......
...@@ -28,13 +28,6 @@ using OpAttrChecker = framework::OpAttrChecker; ...@@ -28,13 +28,6 @@ using OpAttrChecker = framework::OpAttrChecker;
using Scope = framework::Scope; using Scope = framework::Scope;
using DeviceContext = platform::DeviceContext; using DeviceContext = platform::DeviceContext;
class EmptyOp : public OperatorBase {
public:
using OperatorBase::OperatorBase;
void InferShape(const Scope &scope) const override {}
void Run(const Scope &scope, const DeviceContext &dev_ctx) const override {}
};
class RowWiseAddOpMaker : public OpProtoAndCheckerMaker { class RowWiseAddOpMaker : public OpProtoAndCheckerMaker {
public: public:
RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker) RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
...@@ -155,19 +148,16 @@ class AddOpMaker : public OpProtoAndCheckerMaker { ...@@ -155,19 +148,16 @@ class AddOpMaker : public OpProtoAndCheckerMaker {
namespace f = paddle::framework; namespace f = paddle::framework;
namespace ops = paddle::operators; namespace ops = paddle::operators;
using EnforceNotMet = paddle::platform::EnforceNotMet; using EnforceNotMet = paddle::platform::EnforceNotMet;
REGISTER_OP(rowwise_add, f::EmptyOp, f::RowWiseAddOpMaker); REGISTER_OP(rowwise_add, f::NOP, f::RowWiseAddOpMaker, rowwise_add_grad,
REGISTER_GRADIENT_OP(rowwise_add, rowwise_add_grad, f::EmptyOp); f::NOP);
REGISTER_OP(mul, f::EmptyOp, f::MulOpMaker); REGISTER_OP(mul, f::NOP, f::MulOpMaker, mul_grad, f::NOP);
REGISTER_GRADIENT_OP(mul, mul_grad, f::EmptyOp); REGISTER_OP(sigmoid, f::NOP, f::SigmoidOpMaker, sigmoid_grad, f::NOP);
REGISTER_OP(sigmoid, f::EmptyOp, f::SigmoidOpMaker); REGISTER_OP_WITHOUT_GRADIENT(nograd, f::NOP, f::NoGradOpMaker);
REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, f::EmptyOp); REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, f::NOP, f::FillZeroOpMaker);
REGISTER_OP(nograd, f::EmptyOp, f::NoGradOpMaker); REGISTER_OP(add, f::NOP, f::AddOpMaker, add_grad, f::NOP);
REGISTER_OP(fill_zeros_like, f::EmptyOp, f::FillZeroOpMaker); REGISTER_OP_WITHOUT_GRADIENT(fc, f::FcOp, f::FcOpMaker);
REGISTER_OP(add, f::EmptyOp, f::AddOpMaker); REGISTER_OP(many_output_op, f::NOP, f::ManyOutputOpMaker, many_output_op_grad,
REGISTER_GRADIENT_OP(add, add_grad, f::EmptyOp); f::NOP);
REGISTER_OP(fc, f::FcOp, f::FcOpMaker);
REGISTER_OP(many_output_op, f::EmptyOp, f::ManyOutputOpMaker);
REGISTER_GRADIENT_OP(many_output_op, many_output_op_grad, f::EmptyOp);
TEST(Backward, simple_op_grad) { TEST(Backward, simple_op_grad) {
auto fwd = f::OpRegistry::CreateOp( auto fwd = f::OpRegistry::CreateOp(
......
...@@ -19,16 +19,14 @@ namespace paddle { ...@@ -19,16 +19,14 @@ namespace paddle {
namespace framework { namespace framework {
enum class OpArgType { IN, OUT }; enum class OpArgType { IN, OUT };
static void TransOpArg(const OperatorBase* src_op, static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type,
OperatorBase::VarNameMap* vars, bool is_grad, OperatorBase::VarNameMap* vars) {
const OpArgType& src_type, bool is_grad) {
const auto& src_inout = const auto& src_inout =
src_type == OpArgType::IN ? src_op->Inputs() : src_op->Outputs(); src_type == OpArgType::IN ? src_op->Inputs() : src_op->Outputs();
auto& dst_inout = *vars; auto& dst_inout = *vars;
const OpProto* proto = OpRegistry::op_info_map().at(src_op->Type()).proto_;
const OpProto& proto = OpProtos().at(src_op->Type());
const auto& src_arg_list = const auto& src_arg_list =
src_type == OpArgType::IN ? proto.inputs() : proto.outputs(); src_type == OpArgType::IN ? proto->inputs() : proto->outputs();
for (const auto& arg : src_arg_list) { for (const auto& arg : src_arg_list) {
if (arg.no_gradient() && !is_grad) continue; if (arg.no_gradient() && !is_grad) continue;
const std::string src_name = arg.name(); const std::string src_name = arg.name();
...@@ -42,22 +40,26 @@ static void TransOpArg(const OperatorBase* src_op, ...@@ -42,22 +40,26 @@ static void TransOpArg(const OperatorBase* src_op,
} }
OperatorBase* BuildGradOp(const OperatorBase* op) { OperatorBase* BuildGradOp(const OperatorBase* op) {
auto gop_type_it = OpRegistry::grad_ops().find(op->Type()); auto it = OpRegistry::op_info_map().find(op->Type());
PADDLE_ENFORCE(gop_type_it != OpRegistry::grad_ops().end(), PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(),
"Operator %s do not register gradient type", op->Type()); "'%s' has not been registered.", op->Type());
auto& grad_op_type = gop_type_it->second; PADDLE_ENFORCE(it->second.proto_ != nullptr, "'%s' has no OpProto.",
op->Type());
std::string grad_op_type = it->second.grad_op_type_;
PADDLE_ENFORCE(!grad_op_type.empty(), "'%s' has no gradient operator.",
op->Type());
OperatorBase::VarNameMap inputs; OperatorBase::VarNameMap inputs;
OperatorBase::VarNameMap outputs; OperatorBase::VarNameMap outputs;
TransOpArg(op, &inputs, OpArgType::IN, false); // I TransOpArg(op, OpArgType::IN, false, &inputs); // I
TransOpArg(op, &inputs, OpArgType::OUT, false); // O TransOpArg(op, OpArgType::OUT, false, &inputs); // O
TransOpArg(op, &inputs, OpArgType::OUT, true); // OG TransOpArg(op, OpArgType::OUT, true, &inputs); // OG
TransOpArg(op, &outputs, OpArgType::IN, true); // IG TransOpArg(op, OpArgType::IN, true, &outputs); // IG
auto gop_it = OpRegistry::op_creators().find(grad_op_type);
PADDLE_ENFORCE(gop_it != OpRegistry::op_creators().end(),
"Operator %s 's Gradient %s's creator cannot be found",
op->Type(), grad_op_type);
return gop_it->second(grad_op_type, inputs, outputs, op->Attrs()); it = OpRegistry::op_info_map().find(grad_op_type);
PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(),
"'%s' has not been registered.", grad_op_type);
return it->second.creator_(grad_op_type, inputs, outputs, op->Attrs());
} }
} // namespace framework } // namespace framework
......
...@@ -8,14 +8,6 @@ USE_OP(add_two); ...@@ -8,14 +8,6 @@ USE_OP(add_two);
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class NOP : public OperatorBase {
public:
using OperatorBase::OperatorBase;
void InferShape(const Scope &scope) const override {}
void Run(const Scope &scope,
const platform::DeviceContext &dev_ctx) const override {}
};
class MutiInOutOpMaker : public OpProtoAndCheckerMaker { class MutiInOutOpMaker : public OpProtoAndCheckerMaker {
public: public:
MutiInOutOpMaker(OpProto *proto, OpAttrChecker *op_checker) MutiInOutOpMaker(OpProto *proto, OpAttrChecker *op_checker)
...@@ -62,10 +54,8 @@ TEST(GradOpBuilder, AddTwo) { ...@@ -62,10 +54,8 @@ TEST(GradOpBuilder, AddTwo) {
EXPECT_EQ(grad_add_op->Output(f::GradVarName("Y")), f::GradVarName("y")); EXPECT_EQ(grad_add_op->Output(f::GradVarName("Y")), f::GradVarName("y"));
} }
REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker); REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker, mult_io_grad, f::NOP);
REGISTER_GRADIENT_OP(mult_io, mult_io_grad, f::NOP); REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker, io_ignored_grad, f::NOP);
REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker);
REGISTER_GRADIENT_OP(io_ignored, io_ignored_grad, f::NOP);
TEST(GradOpBuilder, MutiInOut) { TEST(GradOpBuilder, MutiInOut) {
std::shared_ptr<f::OperatorBase> test_op(f::OpRegistry::CreateOp( std::shared_ptr<f::OperatorBase> test_op(f::OpRegistry::CreateOp(
......
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include <algorithm> #include <algorithm>
#include <atomic> #include <atomic>
#include <type_traits> #include <type_traits>
#include <typeinfo>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include "paddle/framework/attribute.h" #include "paddle/framework/attribute.h"
...@@ -119,6 +120,12 @@ class OpProtoAndCheckerMaker { ...@@ -119,6 +120,12 @@ class OpProtoAndCheckerMaker {
bool validated_{false}; bool validated_{false};
}; };
class NOPMaker : public OpProtoAndCheckerMaker {
public:
NOPMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {}
};
class OpRegistry { class OpRegistry {
using VarNameMap = OperatorBase::VarNameMap; using VarNameMap = OperatorBase::VarNameMap;
using OpCreator = std::function<OperatorBase*( using OpCreator = std::function<OperatorBase*(
...@@ -126,45 +133,56 @@ class OpRegistry { ...@@ -126,45 +133,56 @@ class OpRegistry {
const VarNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>; const VarNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>;
public: public:
template <typename OpType, typename ProtoMakerType> struct OpInfo {
static void RegisterOp(const std::string& op_type) { OpCreator creator_;
op_creators()[op_type] = []( std::string grad_op_type_;
const std::string& type, const VarNameMap& inputs, OpProto* proto_;
const VarNameMap& outputs, const AttributeMap& attrs) { OpAttrChecker* checker_;
};
template <typename OpType, typename ProtoMakerType, typename GradOpType>
static void RegisterOp(const std::string& op_type,
const std::string& grad_op_type) {
PADDLE_ENFORCE(op_info_map().count(op_type) == 0,
"'%s' is registered more than once.", op_type);
OpInfo op_info;
op_info.creator_ = [](const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs,
const AttributeMap& attrs) {
return new OpType(type, inputs, outputs, attrs); return new OpType(type, inputs, outputs, attrs);
}; };
OpAttrChecker& op_checker = op_checkers()[op_type]; op_info.grad_op_type_ = grad_op_type;
OpProto& op_proto = OpProtos()[op_type]; if (std::type_index(typeid(ProtoMakerType)) !=
auto maker = ProtoMakerType(&op_proto, &op_checker); std::type_index(typeid(NOPMaker))) {
op_info.proto_ = new OpProto;
op_info.checker_ = new OpAttrChecker;
auto maker = ProtoMakerType(op_info.proto_, op_info.checker_);
maker.Validate(); maker.Validate();
op_proto.set_type(op_type); op_info.proto_->set_type(op_type);
PADDLE_ENFORCE( PADDLE_ENFORCE(
op_proto.IsInitialized(), op_info.proto_->IsInitialized(),
"Fail to initialize %s's OpProto, because %s is not initialized", "Fail to initialize %s's OpProto, because %s is not initialized",
op_type, op_proto.InitializationErrorString()); op_type, op_info.proto_->InitializationErrorString());
} else {
op_info.proto_ = nullptr;
op_info.checker_ = nullptr;
}
op_info_map().insert(std::make_pair(op_type, op_info));
// register gradient op
if (!grad_op_type.empty()) {
RegisterOp<GradOpType, NOPMaker, NOP>(grad_op_type, "");
} }
template <typename GradOpType>
static void RegisterGradOp(const std::string& op_type,
const std::string& grad_op_type) {
op_creators()[grad_op_type] = [](
const std::string& type, const VarNameMap& inputs,
const VarNameMap& outputs, const AttributeMap& attrs) {
return new GradOpType(type, inputs, outputs, attrs);
};
grad_ops()[op_type] = grad_op_type;
} }
static std::shared_ptr<OperatorBase> CreateOp(const std::string& type, static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
const VarNameMap& inputs, const VarNameMap& inputs,
const VarNameMap& outputs, const VarNameMap& outputs,
AttributeMap attrs) { AttributeMap attrs) {
auto op_create_it = op_creators().find(type); auto it = op_info_map().find(type);
PADDLE_ENFORCE(op_create_it != op_creators().end(), PADDLE_ENFORCE(it != op_info_map().end(),
"Operator %s cannot be found.", type); "Operator '%s' has not been registered.", type);
op_checkers().at(type).Check(attrs); it->second.checker_->Check(attrs);
auto op = op_create_it->second(type, inputs, outputs, attrs); auto op = it->second.creator_(type, inputs, outputs, attrs);
return std::shared_ptr<OperatorBase>(op); return std::shared_ptr<OperatorBase>(op);
} }
...@@ -199,49 +217,32 @@ class OpRegistry { ...@@ -199,49 +217,32 @@ class OpRegistry {
return grad_op; return grad_op;
} }
static std::unordered_map<std::string, std::string>& grad_ops() { static std::unordered_map<std::string, const OpInfo>& op_info_map() {
static std::unordered_map<std::string, std::string> grad_ops_; static std::unordered_map<std::string, const OpInfo> op_info_map_;
return grad_ops_; return op_info_map_;
}
static std::unordered_map<std::string, OpCreator>& op_creators() {
static std::unordered_map<std::string, OpCreator> op_creators_;
return op_creators_;
}
private:
static std::unordered_map<std::string, OpAttrChecker>& op_checkers() {
static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
return op_checkers_;
} }
}; };
class Registrar { class Registrar {
public: public:
// In our design, various kinds of classes, e.g., operators and kernels, have // In our design, various kinds of classes, e.g., operators and kernels,
// their corresponding registry and registrar. The action of registration is // have their corresponding registry and registrar. The action of
// in the constructor of a global registrar variable, which, however, are not // registration is in the constructor of a global registrar variable, which,
// used in the code that calls package framework, and would be removed from // however, are not used in the code that calls package framework, and would
// the generated binary file by the linker. To avoid such removal, we add // be removed from the generated binary file by the linker. To avoid such
// Touch to all registrar classes and make USE_OP macros to call this // removal, we add Touch to all registrar classes and make USE_OP macros to
// method. So, as long as the callee code calls USE_OP, the global // call this method. So, as long as the callee code calls USE_OP, the global
// registrar variable won't be removed by the linker. // registrar variable won't be removed by the linker.
void Touch() {} void Touch() {}
}; };
template <typename OpType, typename ProtoMakerType> template <typename OpType, typename ProtoMakerType, typename GradOpType>
class OpRegistrar : public Registrar { class OpRegistrar : public Registrar {
public: public:
explicit OpRegistrar(const char* op_type) { explicit OpRegistrar(const char* op_type) { OpRegistrar(op_type, ""); }
OpRegistry::RegisterOp<OpType, ProtoMakerType>(op_type); OpRegistrar(const char* op_type, const char* grad_op_type) {
} OpRegistry::RegisterOp<OpType, ProtoMakerType, GradOpType>(op_type,
}; grad_op_type);
template <typename GradOpType>
class GradOpRegistrar : public Registrar {
public:
GradOpRegistrar(const char* op_type, const char* grad_op_type) {
OpRegistry::RegisterGradOp<GradOpType>(op_type, grad_op_type);
} }
}; };
...@@ -267,30 +268,20 @@ class OpKernelRegistrar : public Registrar { ...@@ -267,30 +268,20 @@ class OpKernelRegistrar : public Registrar {
/** /**
* Macro to register Operator. * Macro to register Operator.
*/ */
#define REGISTER_OP(op_type, op_class, op_maker_class) \ #define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, \
grad_op_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \ STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \ __reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \
static ::paddle::framework::OpRegistrar<op_class, op_maker_class> \ static ::paddle::framework::OpRegistrar<op_class, op_maker_class, \
__op_registrar_##op_type##__(#op_type); \ grad_op_class> \
__op_registrar_##op_type##__(#op_type, #grad_op_type); \
int TouchOpRegistrar_##op_type() { \ int TouchOpRegistrar_##op_type() { \
__op_registrar_##op_type##__.Touch(); \ __op_registrar_##op_type##__.Touch(); \
return 0; \ return 0; \
} }
/** #define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \
* Macro to register Gradient Operator. REGISTER_OP(op_type, op_class, op_maker_class, , ::paddle::framework::NOP)
*/
#define REGISTER_GRADIENT_OP(op_type, grad_op_type, grad_op_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_gradient_op__##op_type##_##grad_op_type, \
"REGISTER_GRADIENT_OP must be called in global namespace"); \
static ::paddle::framework::GradOpRegistrar<grad_op_class> \
__op_gradient_registrar_##op_type##_##grad_op_type##__(#op_type, \
#grad_op_type); \
int TouchOpGradientRegistrar_##op_type() { \
__op_gradient_registrar_##op_type##_##grad_op_type##__.Touch(); \
return 0; \
}
/** /**
* Macro to register OperatorKernel. * Macro to register OperatorKernel.
...@@ -306,14 +297,6 @@ class OpKernelRegistrar : public Registrar { ...@@ -306,14 +297,6 @@ class OpKernelRegistrar : public Registrar {
return 0; \ return 0; \
} }
/**
* Macro to Forbid user register Gradient Operator.
*/
#define NO_GRADIENT(op_type) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_gradient_op__##op_type##_##op_type##_grad, \
"NO_GRADIENT must be called in global namespace")
#define REGISTER_OP_GPU_KERNEL(op_type, ...) \ #define REGISTER_OP_GPU_KERNEL(op_type, ...) \
REGISTER_OP_KERNEL(op_type, GPU, ::paddle::platform::GPUPlace, __VA_ARGS__) REGISTER_OP_KERNEL(op_type, GPU, ::paddle::platform::GPUPlace, __VA_ARGS__)
...@@ -332,23 +315,6 @@ class OpKernelRegistrar : public Registrar { ...@@ -332,23 +315,6 @@ class OpKernelRegistrar : public Registrar {
static int use_op_itself_##op_type##_ __attribute__((unused)) = \ static int use_op_itself_##op_type##_ __attribute__((unused)) = \
TouchOpRegistrar_##op_type() TouchOpRegistrar_##op_type()
// TODO(fengjiayi): Most ops' gradient op have not been compeleted. So we use
// `NO_GRAD` to disable micro USE_OP_GRADIENT(op_type). Otherwise the code can't
// be compiled. `NO_GRAD` should be removed after all gradient ops are
// compeleted.
#define NO_GRAD
#ifndef NO_GRAD
#define USE_OP_GRADIENT(op_type) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__use_op_gradient_##op_type, \
"USE_OP_GRADIENT must be called in global namespace"); \
extern int TouchOpGradientRegistrar_##op_type(); \
static int use_op_gradient_##op_type##_ __attribute__((unused)) = \
TouchOpGradientRegistrar_##op_type()
#else
#define USE_OP_GRADIENT(op_type)
#endif
#define USE_OP_DEVICE_KERNEL(op_type, DEVICE_TYPE) \ #define USE_OP_DEVICE_KERNEL(op_type, DEVICE_TYPE) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \ STATIC_ASSERT_GLOBAL_NAMESPACE( \
__use_op_kernel_##op_type##_##DEVICE_TYPE##__, \ __use_op_kernel_##op_type##_##DEVICE_TYPE##__, \
...@@ -368,18 +334,13 @@ class OpKernelRegistrar : public Registrar { ...@@ -368,18 +334,13 @@ class OpKernelRegistrar : public Registrar {
USE_OP_DEVICE_KERNEL(op_type, GPU) USE_OP_DEVICE_KERNEL(op_type, GPU)
#endif #endif
#define USE_NO_GRAD_OP(op_type) \ #define USE_CPU_ONLY_OP(op_type) \
USE_OP_ITSELF(op_type); \
USE_OP_KERNEL(op_type)
#define USE_CPU_OP(op_type) \
USE_OP_ITSELF(op_type); \ USE_OP_ITSELF(op_type); \
USE_OP_DEVICE_KERNEL(op_type, CPU); \ USE_OP_DEVICE_KERNEL(op_type, CPU);
USE_OP_GRADIENT(op_type)
#define USE_OP(op_type) \ #define USE_OP(op_type) \
USE_NO_GRAD_OP(op_type); \ USE_OP_ITSELF(op_type); \
USE_OP_GRADIENT(op_type) USE_OP_KERNEL(op_type)
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -59,10 +59,9 @@ static void BuildVar(const std::string& param_name, ...@@ -59,10 +59,9 @@ static void BuildVar(const std::string& param_name,
var->add_arguments(arg_name); var->add_arguments(arg_name);
} }
} }
REGISTER_OP_WITHOUT_GRADIENT(cos_sim, paddle::framework::CosineOp,
REGISTER_OP(cos_sim, paddle::framework::CosineOp,
paddle::framework::CosineOpProtoAndCheckerMaker); paddle::framework::CosineOpProtoAndCheckerMaker);
REGISTER_OP(my_test_op, paddle::framework::MyTestOp, REGISTER_OP_WITHOUT_GRADIENT(my_test_op, paddle::framework::MyTestOp,
paddle::framework::MyTestOpProtoAndCheckerMaker); paddle::framework::MyTestOpProtoAndCheckerMaker);
TEST(OpRegistry, CreateOp) { TEST(OpRegistry, CreateOp) {
......
...@@ -33,14 +33,6 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const { ...@@ -33,14 +33,6 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
} }
#endif #endif
static std::unordered_map<std::string, OpProto>* g_op_protos = nullptr;
std::unordered_map<std::string, OpProto>& OpProtos() {
if (g_op_protos == nullptr) {
g_op_protos = new std::unordered_map<std::string, OpProto>();
}
return *g_op_protos;
}
const std::string& OperatorBase::Input(const std::string& name) const { const std::string& OperatorBase::Input(const std::string& name) const {
auto& ins = Inputs(name); auto& ins = Inputs(name);
PADDLE_ENFORCE_EQ(ins.size(), 1UL, PADDLE_ENFORCE_EQ(ins.size(), 1UL,
...@@ -149,14 +141,18 @@ std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const { ...@@ -149,14 +141,18 @@ std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
} }
return ret_val; return ret_val;
} }
auto it = OpProtos().find(type_); auto it = OpRegistry::op_info_map().find(type_);
PADDLE_ENFORCE( PADDLE_ENFORCE(
it != OpProtos().end(), it != OpRegistry::op_info_map().end(),
"Operator %s not registered, cannot figure out intermediate outputs", "Operator %s not registered, cannot figure out intermediate outputs",
type_); type_);
PADDLE_ENFORCE(
it->second.proto_ != nullptr,
"Operator %s has no OpProto, cannot figure out intermediate outputs",
type_);
// get all OpProto::Var for outputs // get all OpProto::Var for outputs
for (auto& o : it->second.outputs()) { for (auto& o : it->second.proto_->outputs()) {
// ignore all intermediate output // ignore all intermediate output
if (o.intermediate()) continue; if (o.intermediate()) continue;
auto out = outputs_.find(o.name()); auto out = outputs_.find(o.name());
......
...@@ -50,8 +50,6 @@ inline std::string GradVarName(const std::string& var_name) { ...@@ -50,8 +50,6 @@ inline std::string GradVarName(const std::string& var_name) {
return var_name + kGradVarSuffix; return var_name + kGradVarSuffix;
} }
extern std::unordered_map<std::string, OpProto>& OpProtos();
class OperatorBase; class OperatorBase;
class InferShapeContext; class InferShapeContext;
class ExecutionContext; class ExecutionContext;
...@@ -132,6 +130,14 @@ class OperatorBase { ...@@ -132,6 +130,14 @@ class OperatorBase {
AttributeMap attrs_; AttributeMap attrs_;
}; };
class NOP : public OperatorBase {
public:
using OperatorBase::OperatorBase;
void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {}
};
class InferShapeContext { class InferShapeContext {
public: public:
InferShapeContext(const OperatorBase& op, const Scope& scope) InferShapeContext(const OperatorBase& op, const Scope& scope)
...@@ -213,7 +219,7 @@ class InferShapeContext { ...@@ -213,7 +219,7 @@ class InferShapeContext {
[&](const std::string& sub_name) { [&](const std::string& sub_name) {
auto var = scope_.FindVar(sub_name); auto var = scope_.FindVar(sub_name);
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
var, "MultiOutput(%s:%s) should not be nullptr", name, var, "MultiOutput(%s:%s) should not be nullptr.", name,
sub_name); sub_name);
return var->GetMutable<T>(); return var->GetMutable<T>();
}); });
......
...@@ -65,7 +65,8 @@ static void BuildVar(const std::string& param_name, ...@@ -65,7 +65,8 @@ static void BuildVar(const std::string& param_name,
} }
} }
REGISTER_OP(test_operator, paddle::framework::OpWithoutKernelTest, REGISTER_OP_WITHOUT_GRADIENT(
test_operator, paddle::framework::OpWithoutKernelTest,
paddle::framework::OpeWithoutKernelTestProtoAndCheckerMaker); paddle::framework::OpeWithoutKernelTestProtoAndCheckerMaker);
TEST(OperatorBase, all) { TEST(OperatorBase, all) {
...@@ -184,7 +185,8 @@ class CPUKernalMultiInputsTest : public OpKernel { ...@@ -184,7 +185,8 @@ class CPUKernalMultiInputsTest : public OpKernel {
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_OP(op_with_kernel, paddle::framework::OpWithKernelTest, REGISTER_OP_WITHOUT_GRADIENT(
op_with_kernel, paddle::framework::OpWithKernelTest,
paddle::framework::OpKernelTestProtoAndCheckerMaker); paddle::framework::OpKernelTestProtoAndCheckerMaker);
REGISTER_OP_CPU_KERNEL(op_with_kernel, REGISTER_OP_CPU_KERNEL(op_with_kernel,
paddle::framework::CPUKernelTest<float, float>); paddle::framework::CPUKernelTest<float, float>);
...@@ -210,7 +212,8 @@ TEST(OpKernel, all) { ...@@ -210,7 +212,8 @@ TEST(OpKernel, all) {
ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1); ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1);
} }
REGISTER_OP(op_multi_inputs_with_kernel, paddle::framework::OpWithKernelTest, REGISTER_OP_WITHOUT_GRADIENT(
op_multi_inputs_with_kernel, paddle::framework::OpWithKernelTest,
paddle::framework::OpKernelTestMultiInputsProtoAndCheckerMaker); paddle::framework::OpKernelTestMultiInputsProtoAndCheckerMaker);
REGISTER_OP_CPU_KERNEL(op_multi_inputs_with_kernel, REGISTER_OP_CPU_KERNEL(op_multi_inputs_with_kernel,
paddle::framework::CPUKernalMultiInputsTest); paddle::framework::CPUKernalMultiInputsTest);
......
...@@ -30,8 +30,8 @@ limitations under the License. */ ...@@ -30,8 +30,8 @@ limitations under the License. */
namespace py = pybind11; namespace py = pybind11;
USE_OP(add_two); USE_OP(add_two);
USE_CPU_OP(onehot_cross_entropy); USE_CPU_ONLY_OP(onehot_cross_entropy);
USE_NO_GRAD_OP(sgd); USE_OP(sgd);
USE_OP(mul); USE_OP(mul);
USE_OP(mean); USE_OP(mean);
USE_OP(sigmoid); USE_OP(sigmoid);
...@@ -160,13 +160,16 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -160,13 +160,16 @@ All parameter, weight, gradient are variables in Paddle.
//! @note: Be careful! PyBind will return std::string as an unicode, not //! @note: Be careful! PyBind will return std::string as an unicode, not
//! Python str. If you want a str object, you should cast them in Python. //! Python str. If you want a str object, you should cast them in Python.
m.def("get_all_op_protos", []() -> std::vector<py::bytes> { m.def("get_all_op_protos", []() -> std::vector<py::bytes> {
auto &protos = OpProtos(); auto &op_info_map = OpRegistry::op_info_map();
std::vector<py::bytes> ret_values; std::vector<py::bytes> ret_values;
for (auto it = protos.begin(); it != protos.end(); ++it) { for (auto it = op_info_map.begin(); it != op_info_map.end(); ++it) {
PADDLE_ENFORCE(it->second.IsInitialized(), const OpProto *proto = it->second.proto_;
"OpProto must all be initialized"); if (proto == nullptr) {
continue;
}
PADDLE_ENFORCE(proto->IsInitialized(), "OpProto must all be initialized");
std::string str; std::string str;
PADDLE_ENFORCE(it->second.SerializeToString(&str), PADDLE_ENFORCE(proto->SerializeToString(&str),
"Serialize OpProto Error. This could be a bug of Paddle."); "Serialize OpProto Error. This could be a bug of Paddle.");
ret_values.push_back(py::bytes(str)); ret_values.push_back(py::bytes(str));
} }
......
...@@ -44,6 +44,8 @@ endfunction() ...@@ -44,6 +44,8 @@ endfunction()
add_subdirectory(math) add_subdirectory(math)
cc_test(gather_test SRCS gather_test.cc DEPS tensor) cc_test(gather_test SRCS gather_test.cc DEPS tensor)
cc_test(scatter_test SRCS scatter_test.cc DEPS tensor)
cc_library(net_op SRCS net_op.cc DEPS op_registry) cc_library(net_op SRCS net_op.cc DEPS op_registry)
cc_test(net_op_test SRCS net_op_test.cc DEPS net_op) cc_test(net_op_test SRCS net_op_test.cc DEPS net_op)
......
...@@ -57,8 +57,7 @@ class AddOpGrad : public framework::OperatorWithKernel { ...@@ -57,8 +57,7 @@ class AddOpGrad : public framework::OperatorWithKernel {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(add_two, ops::AddOp, ops::AddOpMaker); REGISTER_OP(add_two, ops::AddOp, ops::AddOpMaker, add_two_grad, ops::AddOpGrad);
REGISTER_GRADIENT_OP(add_two, add_two_grad, ops::AddOpGrad);
REGISTER_OP_CPU_KERNEL(add_two, REGISTER_OP_CPU_KERNEL(add_two,
ops::AddKernel<paddle::platform::CPUPlace, float>); ops::AddKernel<paddle::platform::CPUPlace, float>);
...@@ -68,12 +68,11 @@ OnehotCrossEntropy Operator. ...@@ -68,12 +68,11 @@ OnehotCrossEntropy Operator.
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp, REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp,
ops::OnehotCrossEntropyOpMaker); ops::OnehotCrossEntropyOpMaker, onehot_cross_entropy_grad,
ops::OnehotCrossEntropyGradientOp);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
onehot_cross_entropy, onehot_cross_entropy,
ops::OnehotCrossEntropyOpKernel<paddle::platform::CPUPlace, float>); ops::OnehotCrossEntropyOpKernel<paddle::platform::CPUPlace, float>);
REGISTER_GRADIENT_OP(onehot_cross_entropy, onehot_cross_entropy_grad,
ops::OnehotCrossEntropyGradientOp);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
onehot_cross_entropy_grad, onehot_cross_entropy_grad,
ops::OnehotCrossEntropyGradientOpKernel<paddle::platform::CPUPlace, float>); ops::OnehotCrossEntropyGradientOpKernel<paddle::platform::CPUPlace, float>);
...@@ -46,7 +46,8 @@ The output will have the same size with input. ...@@ -46,7 +46,8 @@ The output will have the same size with input.
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(fill_zeros_like, ops::FillZerosLikeOp, ops::FillZerosLikeOpMaker); REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, ops::FillZerosLikeOp,
ops::FillZerosLikeOpMaker);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
fill_zeros_like, fill_zeros_like,
ops::FillZerosLikeKernel<paddle::platform::CPUPlace, float>); ops::FillZerosLikeKernel<paddle::platform::CPUPlace, float>);
...@@ -29,7 +29,7 @@ void CPUGather(const T* params, const int* indices, const int slice_size, ...@@ -29,7 +29,7 @@ void CPUGather(const T* params, const int* indices, const int slice_size,
const int index_size, T* output) { const int index_size, T* output) {
const size_t slice_bytes = slice_size * sizeof(T); const size_t slice_bytes = slice_size * sizeof(T);
for (size_t i = 0; i < index_size; ++i) { for (int i = 0; i < index_size; ++i) {
int index_ = indices[i]; int index_ = indices[i];
memcpy(output + i * slice_size, params + index_ * slice_size, slice_bytes); memcpy(output + i * slice_size, params + index_ * slice_size, slice_bytes);
} }
...@@ -60,7 +60,7 @@ void Gather(const platform::Place& place, const paddle::framework::Tensor* src, ...@@ -60,7 +60,7 @@ void Gather(const platform::Place& place, const paddle::framework::Tensor* src,
// slice size // slice size
int slice_size = 1; int slice_size = 1;
for (size_t i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i]; for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
// Gathering // Gathering
if (platform::is_cpu_place(place)) { if (platform::is_cpu_place(place)) {
......
...@@ -35,7 +35,7 @@ TEST(Gather, GatherData) { ...@@ -35,7 +35,7 @@ TEST(Gather, GatherData) {
p_src = src->mutable_data<int>(make_ddim({3, 4}), CPUPlace()); p_src = src->mutable_data<int>(make_ddim({3, 4}), CPUPlace());
p_index = index->mutable_data<int>(make_ddim({2}), CPUPlace()); p_index = index->mutable_data<int>(make_ddim({2}), CPUPlace());
for (size_t i = 0; i < 12; ++i) p_src[i] = i; for (int i = 0; i < 12; ++i) p_src[i] = i;
p_index[0] = 1; p_index[0] = 1;
p_index[1] = 0; p_index[1] = 0;
...@@ -43,6 +43,6 @@ TEST(Gather, GatherData) { ...@@ -43,6 +43,6 @@ TEST(Gather, GatherData) {
Gather<int>(CPUPlace(), src, index, output); Gather<int>(CPUPlace(), src, index, output);
for (size_t i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], i + 4); for (int i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], i + 4);
for (size_t i = 4; i < 8; ++i) EXPECT_EQ(p_output[i], i - 4); for (int i = 4; i < 8; ++i) EXPECT_EQ(p_output[i], i - 4);
} }
...@@ -81,5 +81,6 @@ Use to initialize tensor with gaussian random generator. ...@@ -81,5 +81,6 @@ Use to initialize tensor with gaussian random generator.
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(gaussian_random, ops::GaussianRandomOp, ops::GaussianRandomOpMaker); REGISTER_OP_WITHOUT_GRADIENT(gaussian_random, ops::GaussianRandomOp,
ops::GaussianRandomOpMaker);
REGISTER_OP_CPU_KERNEL(gaussian_random, ops::GaussianRandomKernel<float>); REGISTER_OP_CPU_KERNEL(gaussian_random, ops::GaussianRandomKernel<float>);
...@@ -54,9 +54,8 @@ class MeanGradOp : public framework::OperatorWithKernel { ...@@ -54,9 +54,8 @@ class MeanGradOp : public framework::OperatorWithKernel {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker); REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker, mean_grad, ops::MeanGradOp);
REGISTER_OP_CPU_KERNEL(mean, REGISTER_OP_CPU_KERNEL(mean,
ops::MeanKernel<paddle::platform::CPUPlace, float>); ops::MeanKernel<paddle::platform::CPUPlace, float>);
REGISTER_GRADIENT_OP(mean, mean_grad, ops::MeanGradOp);
REGISTER_OP_CPU_KERNEL(mean_grad, REGISTER_OP_CPU_KERNEL(mean_grad,
ops::MeanGradKernel<paddle::platform::CPUPlace, float>); ops::MeanGradKernel<paddle::platform::CPUPlace, float>);
...@@ -70,7 +70,5 @@ class MulOpGrad : public framework::OperatorWithKernel { ...@@ -70,7 +70,5 @@ class MulOpGrad : public framework::OperatorWithKernel {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker); REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, mul_grad, ops::MulOpGrad);
REGISTER_GRADIENT_OP(mul, mul_grad, ops::MulOpGrad);
REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel<paddle::platform::CPUPlace, float>); REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel<paddle::platform::CPUPlace, float>);
...@@ -20,13 +20,6 @@ class TestOp : public framework::OperatorBase { ...@@ -20,13 +20,6 @@ class TestOp : public framework::OperatorBase {
} }
}; };
class EmptyOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope, const DeviceContext& dev_ctx) const override {}
};
template <typename T> template <typename T>
void AssertSameVectorWithoutOrder(const std::vector<T>& expected, void AssertSameVectorWithoutOrder(const std::vector<T>& expected,
const std::vector<T>& actual) { const std::vector<T>& actual) {
...@@ -67,8 +60,8 @@ TEST(OpKernel, all) { ...@@ -67,8 +60,8 @@ TEST(OpKernel, all) {
TEST(NetOp, insert_op) { TEST(NetOp, insert_op) {
NetOp net; NetOp net;
auto op1 = std::shared_ptr<EmptyOp>( auto op1 = std::shared_ptr<framework::NOP>(
new EmptyOp("empty", {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}}, new framework::NOP("empty", {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}},
{{"Out", {"y"}}}, {})); {{"Out", {"y"}}}, {}));
net.AddOp(op1); net.AddOp(op1);
net.InsertOp(0, op1); net.InsertOp(0, op1);
......
...@@ -246,5 +246,6 @@ RecurrentGradientOp::RecurrentGradientOp( ...@@ -246,5 +246,6 @@ RecurrentGradientOp::RecurrentGradientOp(
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OP(recurrent_op, paddle::operators::RecurrentOp, REGISTER_OP_WITHOUT_GRADIENT(
recurrent_op, paddle::operators::RecurrentOp,
paddle::operators::RecurrentAlgorithmProtoAndCheckerMaker); paddle::operators::RecurrentAlgorithmProtoAndCheckerMaker);
...@@ -54,6 +54,7 @@ for i in xrange(X.shape[0]): ...@@ -54,6 +54,7 @@ for i in xrange(X.shape[0]):
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(rowwise_add, ops::RowWiseAddOp, ops::RowWiseAddOpMaker); REGISTER_OP_WITHOUT_GRADIENT(rowwise_add, ops::RowWiseAddOp,
ops::RowWiseAddOpMaker);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
rowwise_add, ops::RowWiseAddKernel<paddle::platform::CPUPlace, float>); rowwise_add, ops::RowWiseAddKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cstring>
#include "paddle/framework/ddim.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/place.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
// Implementation of CPU copy
template <typename T>
void CPUScatterUpdate(const paddle::framework::Tensor* src, const int* index,
const size_t index_size,
paddle::framework::Tensor* output) {
paddle::framework::DDim output_dims = output->dims();
for (size_t i = 0; i < index_size; ++i) {
int index_ = index[i];
paddle::framework::Tensor src_ = *src;
paddle::framework::Tensor output_ = *output;
if (index_size > 1) src_ = src->Slice<T>(i, i + 1);
if (output_dims[0] > 1) output_ = output->Slice<T>(index_, index_ + 1);
auto X = EigenVector<T>::Flatten(src_);
auto Y = EigenVector<T>::Flatten(output_);
Y = X + Y;
}
}
// Implementation of GPU scatter:
template <typename T>
void GPUScatterUpdate(const T* src, const int* index, const int slice_size,
const int index_size, T* output);
/**
* Return a updated tensor from source tensor, scattered according to index:
* dst[i] += src[index[i]]
* input[src]: type-T source Tensor
* input[index]: type-int index Tensor (1-D)
* return: output tensor
*/
template <typename T>
void ScatterUpdate(const platform::Place& place,
const paddle::framework::Tensor* src,
const paddle::framework::Tensor* index,
paddle::framework::Tensor* output) {
// check index of shape 1-D
PADDLE_ENFORCE(index->dims().size() == 1);
int index_size = index->dims()[0];
auto src_dims = src->dims();
auto dst_dims = output->dims();
// check src shape and dst shape should match
for (int i = 1; i < src_dims.size(); i++)
PADDLE_ENFORCE(src_dims[i] == dst_dims[i]);
// slice size
size_t slice_size = 1;
for (int i = 0; i < src_dims.size(); ++i) slice_size *= src_dims[i];
if (platform::is_cpu_place(place)) {
CPUScatterUpdate<T>(src, index->data<int>(), index_size, output);
} else {
}
}
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/scatter.h"
#include "paddle/framework/ddim.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/place.h"
#include <gtest/gtest.h>
#include <iostream>
#include <string>
TEST(scatter, ScatterUpdate) {
using namespace paddle::framework;
using namespace paddle::platform;
using namespace paddle::operators;
Tensor* src = new Tensor();
Tensor* index = new Tensor();
Tensor* output = new Tensor();
float* p_src = nullptr;
int* p_index = nullptr;
p_src = src->mutable_data<float>(make_ddim({1, 4}), CPUPlace());
p_index = index->mutable_data<int>(make_ddim({1}), CPUPlace());
for (size_t i = 0; i < 4; ++i) p_src[i] = float(i);
p_index[0] = 1;
float* p_output = output->mutable_data<float>(make_ddim({4, 4}), CPUPlace());
ScatterUpdate<float>(CPUPlace(), src, index, output);
for (size_t i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], float(0));
for (size_t i = 0; i < 4; ++i) EXPECT_EQ(output->data<float>()[i], float(0));
for (size_t i = 4; i < 8; ++i) EXPECT_EQ(p_output[i], float(i - 4));
for (size_t i = 4; i < 8; ++i)
EXPECT_EQ(output->data<float>()[i], float(i - 4));
for (size_t i = 8; i < 16; ++i) EXPECT_EQ(p_output[i], float(0));
for (size_t i = 8; i < 16; ++i) EXPECT_EQ(output->data<float>()[i], float(0));
}
...@@ -51,6 +51,6 @@ param_out = param - learning_rate * grad; ...@@ -51,6 +51,6 @@ param_out = param - learning_rate * grad;
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(sgd, ops::SGDOp, ops::SGDOpMaker); REGISTER_OP_WITHOUT_GRADIENT(sgd, ops::SGDOp, ops::SGDOpMaker);
REGISTER_OP_CPU_KERNEL(sgd, REGISTER_OP_CPU_KERNEL(sgd,
ops::SGDOpKernel<paddle::platform::CPUPlace, float>); ops::SGDOpKernel<paddle::platform::CPUPlace, float>);
...@@ -52,9 +52,8 @@ class SigmoidOpGrad : public framework::OperatorWithKernel { ...@@ -52,9 +52,8 @@ class SigmoidOpGrad : public framework::OperatorWithKernel {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker); REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker, sigmoid_grad,
REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, ops::SigmoidOpGrad); ops::SigmoidOpGrad);
REGISTER_OP_CPU_KERNEL(sigmoid, REGISTER_OP_CPU_KERNEL(sigmoid,
ops::SigmoidKernel<paddle::platform::CPUPlace, float>); ops::SigmoidKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
......
...@@ -62,9 +62,9 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { ...@@ -62,9 +62,9 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker); REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker, softmax_grad,
ops::SoftmaxOpGrad);
REGISTER_OP_CPU_KERNEL(softmax, REGISTER_OP_CPU_KERNEL(softmax,
ops::SoftmaxKernel<paddle::platform::CPUPlace, float>); ops::SoftmaxKernel<paddle::platform::CPUPlace, float>);
REGISTER_GRADIENT_OP(softmax, softmax_grad, ops::SoftmaxOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
softmax_grad, ops::SoftmaxGradKernel<paddle::platform::CPUPlace, float>); softmax_grad, ops::SoftmaxGradKernel<paddle::platform::CPUPlace, float>);
...@@ -81,7 +81,7 @@ Used to initialize tensor with uniform random generator. ...@@ -81,7 +81,7 @@ Used to initialize tensor with uniform random generator.
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OP(uniform_random, paddle::operators::UniformRandomOp, REGISTER_OP_WITHOUT_GRADIENT(uniform_random, paddle::operators::UniformRandomOp,
paddle::operators::UniformRandomOpMaker); paddle::operators::UniformRandomOpMaker);
REGISTER_OP_CPU_KERNEL(uniform_random, REGISTER_OP_CPU_KERNEL(uniform_random,
paddle::operators::CPUUniformRandomKernel<float>); paddle::operators::CPUUniformRandomKernel<float>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册