Commit d3f219aa authored by Yu Yang

Change IdentityOp to ScaleOp

Parent c108d610
@@ -56,5 +56,5 @@ cc_library(paddle_pybind SHARED
     uniform_random_op
     gaussian_random_op
     fill_zeros_like_op
-    identity_op)
+    scale_op)
 endif(WITH_PYTHON)
@@ -42,7 +42,8 @@ USE_OP(fill_zeros_like);
 USE_OP_ITSELF(recurrent_op);
 USE_OP(gaussian_random);
 USE_OP(uniform_random);
-USE_OP(identity);
+USE_OP(scale);
+USE_OP_ITSELF(identity);
 namespace paddle {
 namespace framework {
......
@@ -105,7 +105,10 @@ class Tensor {
   template <typename T>
   inline Tensor Slice(const int& begin_idx, const int& end_idx) const;
-  platform::Place place() const { return holder_->place(); }
+  platform::Place place() const {
+    PADDLE_ENFORCE_NOT_NULL(holder_, "Tensor must contain a holder when place() is called");
+    return holder_->place();
+  }
  private:
   template <typename T>
......
@@ -68,4 +68,4 @@ op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
            DEPS framework_proto tensor op_registry operator net_op)
 op_library(uniform_random_op
            SRCS uniform_random_op.cc uniform_random_op.cu)
-op_library(identity_op SRCS identity_op.cc identity_op.cu DEPS net_op)
+op_library(scale_op SRCS scale_op.cc scale_op.cu DEPS net_op)
@@ -68,10 +68,15 @@ std::string NetOp::DebugString() const {
 bool NetOp::IsNetOp() const { return true; }
 std::vector<std::string> NetOp::OutputVars(bool has_intermediate) const {
+  std::vector<std::string> all;
+  for (auto& pair : this->outputs_) {
+    for (auto& var_name : pair.second) {
+      all.push_back(var_name);
+    }
+  }
   if (has_intermediate) {
-    return this->outputs_.at(kAll);
+    return all;
   }
-  auto& all = this->outputs_.at(kAll);
   std::vector<std::string> ret_val;
   for (auto& each : all) {
     if (!Contains(intermediate_outputs_, each)) {
......
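Note: the rewritten NetOp::OutputVars above no longer reads a pre-computed kAll entry; it flattens every output slot of outputs_ and, when has_intermediate is false, filters out the intermediate variables afterwards. Below is a minimal Python sketch of that control flow; the slot and variable names are made up for illustration and are not Paddle API.

# Sketch of the new OutputVars logic; 'outputs' and 'intermediate_outputs'
# are hypothetical stand-ins for the NetOp members of the same names.
outputs = {"Out": ["pred"], "hidden": ["h0"]}  # slot name -> variable names
intermediate_outputs = {"h0"}                  # produced and consumed inside the net

def output_vars(has_intermediate):
    all_vars = [v for names in outputs.values() for v in names]
    if has_intermediate:
        return all_vars
    return [v for v in all_vars if v not in intermediate_outputs]

print(output_vars(True))   # ['pred', 'h0']
print(output_vars(False))  # ['pred']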
@@ -12,15 +12,15 @@
    See the License for the specific language governing permissions and
    limitations under the License. */
-#include "paddle/operators/identity_op.h"
+#include "paddle/operators/scale_op.h"
 #include "paddle/operators/net_op.h"
 namespace paddle {
 namespace operators {
-class IdentityOp : public framework::OperatorWithKernel {
+class ScaleOp : public framework::OperatorWithKernel {
  public:
-  IdentityOp(const std::string &type, const VarNameMap &inputs,
+  ScaleOp(const std::string &type, const VarNameMap &inputs,
           const VarNameMap &outputs, const framework::AttributeMap &attrs)
       : OperatorWithKernel(type, inputs, outputs, attrs) {}
@@ -32,40 +32,71 @@ class IdentityOp : public framework::OperatorWithKernel {
   }
 };
-class IdentityOpMaker : public framework::OpProtoAndCheckerMaker {
+template <typename AttrType>
+class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  IdentityOpMaker(framework::OpProto *proto,
-                  framework::OpAttrChecker *op_checker)
+  ScaleOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "The input tensor of identity operator.").NotInGradient();
-    AddOutput("Out", "The output tensor of identity operator.").NotInGradient();
-    AddComment(R"DOC(Identity operator
+    AddInput("X", "The input tensor of scale operator.").NotInGradient();
+    AddOutput("Out", "The output tensor of scale operator.").NotInGradient();
+    AddComment(R"DOC(Scale operator
-The equation is: Out = X
+The equation is: Out = scale*X
 )DOC");
+    AddAttr<AttrType>("scale", "scale of scale operator.").SetDefault(1.0);
   }
 };
-// Identity Op's gradient is identity op, too.
-// Grad(Out=identity_op(X)) => Grad(X) = identity_op(Grad(Out))
-class IdentityGradOp : public NetOp {
+// Grad(Out=scale(X)) => Grad(X) = scale(Grad(Out))
+template <typename AttrType>
+class ScaleGradOp : public NetOp {
  public:
-  IdentityGradOp(const std::string &type, const VarNameMap &inputs,
-                 const VarNameMap &outputs,
-                 const framework::AttributeMap &attrs)
+  ScaleGradOp(const std::string &type, const VarNameMap &inputs,
+              const VarNameMap &outputs, const framework::AttributeMap &attrs)
       : NetOp(type, inputs, outputs, attrs) {
     AddOp(framework::OpRegistry::CreateOp(
-        "identity", {{"X", {Input(framework::GradVarName("Out"))}}},
-        {{"Out", {Output(framework::GradVarName("X"))}}}, {}));
+        "scale", {{"X", {Input(framework::GradVarName("Out"))}}},
+        {{"Out", {Output(framework::GradVarName("X"))}}},
+        {{"scale", GetAttr<AttrType>("scale")}}));
     CompleteAddOp(false);
   }
 };
+// identity is an alias of the scale op. It is also an example of how to create
+// an alias operator.
+template <typename AttrType>
+class IdentityOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  IdentityOpMaker(framework::OpProto *proto,
+                  framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "input tensor of identity op");
+    AddOutput("Out", "output tensor of identity op");
+    AddComment("identity operator. Just an alias of the scale op with scale = 1.0");
+  }
+};
+template <typename AttrType>
+class IdentityOp : public NetOp {
+ public:
+  IdentityOp(const std::string &type, const VarNameMap &inputs,
+             const VarNameMap &outputs, const framework::AttributeMap &attrs)
+      : NetOp(type, inputs, outputs, attrs) {
+    AddOp(framework::OpRegistry::CreateOp(
+        "scale", {{"X", {Input("X")}}}, {{"Out", {Output("Out")}}},
+        {{"scale", static_cast<AttrType>(1)}}));
+  }
+};
 }  // namespace operators
 }  // namespace paddle
 namespace ops = paddle::operators;
-REGISTER_OP(identity, ops::IdentityOp, ops::IdentityOpMaker, identity_grad,
-            ops::IdentityGradOp);
-REGISTER_OP_CPU_KERNEL(identity, ops::IdentityKernel<float>);
+REGISTER_OP(scale, ops::ScaleOp, ops::ScaleOpMaker<float>, scale_grad,
+            ops::ScaleGradOp<float>);
+REGISTER_OP_CPU_KERNEL(scale,
+                       ops::ScaleKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_WITHOUT_GRADIENT(identity, ops::IdentityOp<float>,
+                             ops::IdentityOpMaker<float>);
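The registrations above expose three operators to Python: scale (with its CPU kernel), scale_grad (a NetOp wrapping another scale op), and identity (a NetOp alias that expands to scale with scale = 1.0). A hedged sketch of constructing them through the Operator wrapper imported by the updated tests below; this only builds the op descriptions and would still need a scope and device context to actually run.

# Assumes the compiled paddle.v2.framework package is importable, as in the
# updated unit tests further down in this commit.
from paddle.v2.framework.op import Operator

scale_op = Operator("scale", X="X", Out="Out", scale=3.2)  # Out = 3.2 * X
identity_op = Operator("identity", X="X", Out="Out")       # expands to scale with scale = 1.0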
@@ -12,6 +12,7 @@
    See the License for the specific language governing permissions and
    limitations under the License. */
-#include "paddle/operators/identity_op.h"
+#include "paddle/operators/scale_op.h"
-REGISTER_OP_GPU_KERNEL(identity, paddle::operators::IdentityKernel<float>);
+REGISTER_OP_GPU_KERNEL(
+    scale, paddle::operators::ScaleKernel<paddle::platform::GPUPlace, float>);
@@ -14,17 +14,25 @@
 #pragma once
+#include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
-#include "paddle/memory/memcpy.h"
 namespace paddle {
 namespace operators {
-template <typename T>
-class IdentityKernel : public framework::OpKernel {
+template <typename Place, typename T, typename AttrType = T>
+class ScaleKernel : public framework::OpKernel {
  public:
   virtual void Compute(const framework::ExecutionContext& context) const {
     auto* tensor = context.Output<framework::Tensor>("Out");
     auto* in = context.Input<framework::Tensor>("X");
-    tensor->CopyFrom<T>(*in, in->place());
+    tensor->mutable_data<T>(in->place());
+    auto scale = static_cast<T>(context.op_.GetAttr<AttrType>("scale"));
+    auto eigen_out = framework::EigenVector<T>::Flatten(*tensor);
+    auto eigen_in = framework::EigenVector<T>::Flatten(*in);
+    auto& dev = context.GetEigenDevice<Place>();
+    eigen_out.device(dev) = scale * eigen_in;
   }
 };
......
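Numerically, the new ScaleKernel and ScaleGradOp encode Out = scale * X and dX = scale * dOut (a scale op's gradient is another scale op with the same attribute). Below is a small NumPy stand-in for those relations, assuming the tensors are treated as flat vectors exactly as EigenVector::Flatten does; it does not execute the Paddle kernels themselves.

import numpy as np

x = np.random.random((4, 3)).astype("float32")
scale = -2.3

out = scale * x            # forward: what ScaleKernel computes
d_out = np.ones_like(out)  # pretend upstream gradient of Out
d_x = scale * d_out        # backward: what ScaleGradOp's inner scale op computes

# identity is the scale = 1.0 special case
assert np.allclose(1.0 * x, x)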
@@ -27,4 +27,4 @@ py_test(test_uniform_random_op SRCS test_uniform_random_op.py)
 py_test(test_recurrent_op SRCS test_recurrent_op.py)
 py_test(test_sgd_op SRCS test_sgd_op.py)
 py_test(test_gradient_checker SRCS test_gradient_checker.py)
-py_test(test_identity_op SRCS test_identity_op.py)
+py_test(test_scale_and_identity_op SRCS test_scale_and_identity_op.py)
@@ -160,8 +160,13 @@ class GradientChecker(unittest.TestCase):
             grad_tensor.set(data, place)
         # run backward op
-        for name in backward_op.outputs():
+        backward_outs = backward_op.outputs()
+        backward_names = [
+            item for key in backward_outs for item in backward_outs[key]
+        ]
+        for name in backward_names:
             scope.new_var(name)
         backward_op.infer_shape(scope)
         backward_op.run(scope, ctx)
......
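backward_op.outputs() returns a mapping from each output slot of the backward op to the variable names bound to that slot; the old loop iterated over the slot names only, so the scope was populated with variables named after the slots rather than after the actual gradient variables. A toy illustration of the flattening the new code performs; the slot and variable names below are made up.

backward_outs = {"X@GRAD": ["x_grad"], "b@GRAD": ["b_grad_0", "b_grad_1"]}

backward_names = [
    item for key in backward_outs for item in backward_outs[key]
]
print(backward_names)  # every bound variable name: x_grad, b_grad_0, b_grad_1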
@@ -2,6 +2,7 @@ import unittest
 from op_test_util import OpTestMeta
 from gradient_checker import GradientChecker, create_op
 import numpy as np
+from paddle.v2.framework.op import Operator
 class IdentityTest(unittest.TestCase):
@@ -20,5 +21,23 @@ class IdentityGradOpTest(GradientChecker):
         self.check_grad(op, inputs, set("X"), "Out")
+class ScaleTest(unittest.TestCase):
+    __metaclass__ = OpTestMeta
+    def setUp(self):
+        self.type = "scale"
+        self.inputs = {'X': np.random.random((32, 784)).astype("float32")}
+        self.attrs = {'scale': -2.3}
+        self.outputs = {'Out': self.inputs['X'] * self.attrs['scale']}
+class ScaleGradTest(GradientChecker):
+    def test_normal(self):
+        op = Operator("scale", X="X", Out="Out", scale=3.2)
+        self.check_grad(op,
+                        {"X": np.random.random((10, 10)).astype("float32")},
+                        set("X"), "Out")
 if __name__ == '__main__':
     unittest.main()