提交 c1696696 编写于 作者: Q QI JUN 提交者: GitHub

Merge pull request #3882 from QiJune/refactor_op_py_test

Refactor operator python test framework and add sum operator
......@@ -123,6 +123,15 @@ OperatorBase::OperatorBase(const std::string& type,
CheckAllInputOutputSet();
}
std::vector<std::string> OperatorBase::InputVars() const {
std::vector<std::string> ret_val;
for (auto& o : outputs_) {
ret_val.reserve(ret_val.size() + o.second.size());
ret_val.insert(ret_val.end(), o.second.begin(), o.second.end());
}
return ret_val;
}
std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
std::vector<std::string> ret_val;
if (has_intermediate) {
......
......@@ -94,11 +94,14 @@ class OperatorBase {
const VariableNameMap& Inputs() const { return inputs_; }
const VariableNameMap& Outputs() const { return outputs_; }
//! Get a input with argument's name described in `op_proto`
std::string Input(const std::string& name) const;
//! Get a input which has multiple variables.
const std::vector<std::string>& Inputs(const std::string& name) const;
std::vector<std::string> InputVars() const;
//! Get a output with argument's name described in `op_proto`
std::string Output(const std::string& name) const;
//! Get an output which has multiple variables.
......@@ -311,9 +314,9 @@ class InferShapeContext {
}
template <typename T>
std::vector<const T*> MultiOutput(const std::string& name) const {
std::vector<T*> MultiOutput(const std::string& name) const {
auto names = op_.Outputs(name);
std::vector<const T*> res;
std::vector<T*> res;
res.reserve(names.size());
std::transform(names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) {
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/sum_op.h"
#include <vector>
namespace paddle {
namespace operators {
using framework::Tensor;
class SumOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto ins = ctx.MultiInput<framework::Tensor>("X");
auto *out = ctx.Output<framework::Tensor>("Out");
int N = ins.size();
auto in_dim = ins[0]->dims();
PADDLE_ENFORCE_GT(N, 1, "Input tensors count should > 1.");
for (int i = 1; i < N; i++) {
auto dim = ins[i]->dims();
PADDLE_ENFORCE(in_dim == dim, "Input tensors must have same shape");
}
out->Resize(in_dim);
}
};
class SumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SumOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "the input tensors of sum operator.").AsDuplicable();
AddOutput("Out", "the output tensor of sum operator.");
AddComment(R"DOC(
Sum the input tensors.
)DOC");
}
};
class SumGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto outputs = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
auto dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
for (auto output : outputs) {
output->Resize(dims);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(sum, ops::SumOp, ops::SumOpMaker, sum_grad, ops::SumGradOp);
REGISTER_OP_CPU_KERNEL(sum, ops::SumKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(sum_grad,
ops::SumGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/sum_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(sum, ops::SumKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(sum_grad,
ops::SumGradKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T>
class SumKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto ins = context.MultiInput<Tensor>("X");
auto* out = context.Output<Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
auto place = context.GetEigenDevice<Place>();
auto result = EigenVector<T>::Flatten(*out);
int N = ins.size();
auto in = EigenVector<T>::Flatten(*(ins[0]));
result.device(place) = in;
for (int i = 1; i < N; i++) {
auto in = EigenVector<T>::Flatten(*(ins[i]));
result.device(place) = result + in;
}
}
};
template <typename Place, typename T>
class SumGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* input = context.Input<Tensor>(framework::GradVarName("Out"));
auto outs = context.MultiOutput<Tensor>(framework::GradVarName("X"));
for (auto out : outs) {
out->mutable_data<T>(context.GetPlace());
}
auto place = context.GetEigenDevice<Place>();
auto in = EigenVector<T>::Flatten(*input);
for (auto out : outs) {
auto result = EigenVector<T>::Flatten(*out);
result.device(place) = in;
}
}
};
} // namespace operators
} // namespace paddle
......@@ -51,6 +51,7 @@ USE_CPU_ONLY_OP(gather);
USE_CPU_ONLY_OP(scatter);
USE_OP(top_k);
USE_OP(squared_l2_distance);
USE_OP(sum);
namespace paddle {
namespace framework {
......@@ -216,7 +217,10 @@ All parameter, weight, gradient are variables in Paddle.
-> std::map<std::string, std::vector<std::string>> {
return op.Outputs();
})
.def("output_vars",
[](const OperatorBase &op) { return op.OutputVars(true); })
.def("inputs", [](const OperatorBase &op) { return op.Inputs(); })
.def("input_vars", [](const OperatorBase &op) { return op.InputVars(); })
.def("__str__", &OperatorBase::DebugString)
.def("no_intermediate_outputs",
[](const OperatorBase &op) { return op.OutputVars(false); })
......
......@@ -142,8 +142,8 @@ def create_op_creation_method(op_proto):
return OpInfo(
method=__impl__,
name=op_proto.type,
inputs=[var.name for var in op_proto.inputs],
outputs=[var.name for var in op_proto.outputs],
inputs=[(var.name, var.duplicable) for var in op_proto.inputs],
outputs=[(var.name, var.duplicable) for var in op_proto.outputs],
attrs=[attr.name for attr in op_proto.attrs])
......@@ -180,9 +180,15 @@ class OperatorFactory(object):
return self.op_methods.get(t)
def get_op_input_names(self, type):
return map(lambda x: x[0], self.get_op_info(type).inputs)
def get_op_inputs(self, type):
return self.get_op_info(type).inputs
def get_op_output_names(self, type):
return map(lambda x: x[0], self.get_op_info(type).outputs)
def get_op_outputs(self, type):
return self.get_op_info(type).outputs
def get_op_attr_names(self, type):
......
......@@ -33,5 +33,6 @@ py_test(test_sgd_op SRCS test_sgd_op.py)
py_test(test_gradient_checker SRCS test_gradient_checker.py)
py_test(test_lookup_table SRCS test_lookup_table.py)
py_test(test_scale_and_identity_op SRCS test_scale_and_identity_op.py)
py_test(test_sum_op SRCS test_sum_op.py)
py_test(mnist SRCS mnist.py)
py_test(test_squared_l2_distance_op SRCS test_squared_l2_distance_op.py)
import unittest
import numpy as np
import itertools
import paddle.v2.framework.core as core
from paddle.v2.framework.op import Operator
def grad_var_name(var_name):
return var_name + "@GRAD"
def create_op(scope, op_type, inputs, outputs, attrs=None):
kwargs = dict()
for in_name, in_dup in Operator.get_op_inputs(op_type):
if in_name in inputs:
kwargs[in_name] = []
if in_dup:
sub_in = inputs[in_name]
for sub_in_name in sub_in:
var = scope.new_var(sub_in_name)
kwargs[in_name].append(sub_in_name)
else:
var = scope.new_var(in_name)
kwargs[in_name].append(in_name)
for out_name, out_dup in Operator.get_op_outputs(op_type):
if out_name in outputs:
kwargs[out_name] = []
if out_dup:
sub_in = outputs[out_name]
for sun_in_name in sub_in:
var = scope.new_var(sun_in_name)
kwargs[out_name].append(sun_in_name)
else:
var = scope.new_var(out_name)
kwargs[out_name].append(out_name)
for attr_name in Operator.get_op_attr_names(op_type):
kwargs[attr_name] = attrs[attr_name]
return Operator(op_type, **kwargs)
def set_input(scope, op, inputs, place):
for in_name, in_dup in Operator.get_op_inputs(op.type()):
if in_name in inputs:
if in_dup:
sub_in = inputs[in_name]
for sub_in_name in sub_in:
var = scope.find_var(sub_in_name)
tensor = var.get_tensor()
arr = sub_in[sub_in_name]
tensor.set_dims(arr.shape)
tensor.set(arr, place)
else:
var = scope.find_var(in_name)
tensor = var.get_tensor()
arr = inputs[in_name]
tensor.set_dims(arr.shape)
tensor.set(arr, place)
def set_output_grad(scope, op, outputs, place):
for out_name, out_dup in Operator.get_op_outputs(op.type()):
if out_name in outputs:
if out_dup:
sub_out = outputs[out_name]
for sub_out_name in sub_out:
out_tensor = scope.find_var(sub_out_name).get_tensor()
grad_tensor = scope.new_var(grad_var_name(
sub_out_name)).get_tensor()
grad_tensor.set_dims(out_tensor.shape())
data = np.ones(out_tensor.shape(), dtype=np.float32)
grad_tensor.set(data, place)
else:
out_tensor = scope.find_var(out_name).get_tensor()
grad_tensor = scope.new_var(grad_var_name(out_name)).get_tensor(
)
grad_tensor.set_dims(out_tensor.shape())
data = np.ones(out_tensor.shape(), dtype=np.float32)
grad_tensor.set(data, place)
def get_numeric_gradient(scope,
op,
inputs,
input_to_check,
output_name,
delta=0.005,
in_place=False):
set_input(scope, op, inputs, core.CPUPlace())
op.infer_shape(scope)
tensor_to_check = scope.find_var(input_to_check).get_tensor()
def product(dim):
return reduce(lambda a, b: a * b, dim, 1)
ctx = core.DeviceContext.create(core.CPUPlace())
def get_output():
op.run(scope, ctx)
return np.array(scope.find_var(output_name).get_tensor()).sum()
tensor_to_check = scope.find_var(input_to_check).get_tensor()
tensor_size = product(tensor_to_check.get_dims())
gradient_flat = np.zeros(shape=(tensor_size, ), dtype='float32')
# we only compute gradient of one element each time.
# we use a for loop to compute the gradient of every element.
for i in xrange(tensor_size):
if in_place:
set_input(op, inputs, core.CPUPlace())
# get one input element throw it's index i.
origin = tensor_to_check.get_float_element(i)
# add delta to it, run op and then get the sum of the result tensor.
x_pos = origin + delta
tensor_to_check.set_float_element(i, x_pos)
y_pos = get_output()
if in_place:
set_input(op, inputs, core.CPUPlace())
x_neg = origin - delta
tensor_to_check.set_float_element(i, x_neg)
y_neg = get_output()
tensor_to_check.set_float_element(i, origin)
gradient_flat[i] = (y_pos - y_neg) / delta / 2
return gradient_flat.reshape(tensor_to_check.get_dims())
def get_backward_op(scope, op, no_grad_set):
backward_op = core.Operator.backward(op, no_grad_set)
for input in backward_op.input_vars():
var = scope.new_var(input)
var.get_tensor()
for output in backward_op.output_vars():
var = scope.new_var(output)
var.get_tensor()
return backward_op
def get_gradient(scope, op, inputs, outputs, grad_name, place,
no_grad_set=None):
ctx = core.DeviceContext.create(place)
set_input(scope, op, inputs, place)
op.infer_shape(scope)
op.run(scope, ctx)
if no_grad_set is None:
no_grad_set = set()
backward_op = get_backward_op(scope, op, no_grad_set)
set_output_grad(scope, op, outputs, place)
backward_op.infer_shape(scope)
backward_op.run(scope, ctx)
out = np.array(scope.find_var(grad_name).get_tensor())
return out
class OpTest(unittest.TestCase):
def check_output_with_place(self, place):
self.scope = core.Scope()
self.op = create_op(self.scope, self.op_type, self.inputs, self.outputs)
if isinstance(place, core.GPUPlace) and not self.op.support_gpu():
return
set_input(self.scope, self.op, self.inputs, place)
self.op.infer_shape(self.scope)
ctx = core.DeviceContext.create(place)
self.op.run(self.scope, ctx)
for out_name, out_dup in Operator.get_op_outputs(self.op.type()):
if out_dup:
sub_out = self.outputs[out_name]
for sub_out_name in sub_out:
actual = np.array(
self.scope.find_var(sub_out_name).get_tensor())
expect = sub_out[sub_out_name]
self.assertTrue(
np.allclose(
actual, expect, atol=1e-05),
"output name: " + out_name + "has diff")
else:
actual = np.array(self.scope.find_var(out_name).get_tensor())
expect = self.outputs[out_name]
self.assertTrue(
np.allclose(
actual, expect, atol=1e-05),
"output name: " + out_name + "has diff")
def check_output(self):
places = [core.CPUPlace()]
if core.is_compile_gpu():
places.append(core.GPUPlace(0))
for place in places:
self.check_output_with_place(place)
def __assert_is_close(self, numeric_grads, analytic_grads, names,
max_relative_error, msg_prefix):
for a, b, name in itertools.izip(numeric_grads, analytic_grads, names):
abs_a = np.abs(a)
abs_a[abs_a < 1e-3] = 1
diff_mat = np.abs(a - b) / abs_a
max_diff = np.max(diff_mat)
def err_msg():
offset = np.argmax(diff_mat > max_relative_error)
return "%s Variable %s max gradient diff %f over limit %f, the first " \
"error element is %d" % (
msg_prefix, name, max_diff, max_relative_error, offset)
self.assertLessEqual(max_diff, max_relative_error, err_msg())
def check_grad(self,
inputs_to_check,
output_name,
no_grad_set=None,
in_place=False,
max_relative_error=0.005):
self.scope = core.Scope()
self.op = create_op(self.scope, self.op_type, self.inputs, self.outputs)
if no_grad_set is None:
no_grad_set = set()
numeric_grads = [
get_numeric_gradient(
self.scope,
self.op,
self.inputs,
input_to_check,
output_name,
in_place=in_place) for input_to_check in inputs_to_check
]
grad_names = [
grad_var_name(input_to_check) for input_to_check in inputs_to_check
]
cpu_place = core.CPUPlace()
cpu_analytic_grads = [
get_gradient(self.scope, self.op, self.inputs, self.outputs,
grad_name, cpu_place, no_grad_set)
for grad_name in grad_names
]
self.__assert_is_close(numeric_grads, cpu_analytic_grads, grad_names,
max_relative_error,
"Gradient Check On %s" % str(cpu_place))
if core.is_compile_gpu() and self.op.support_gpu():
gpu_place = core.GPUPlace(0)
gpu_analytic_grads = [
get_gradient(self.scope, self.op, self.inputs, self.outputs,
grad_name, gpu_place, no_grad_set)
for grad_name in grad_names
]
self.__assert_is_close(numeric_grads, gpu_analytic_grads,
grad_names, max_relative_error,
"Gradient Check On %s" % str(gpu_place))
for c_grad, g_grad, name in itertools.izip(
cpu_analytic_grads, gpu_analytic_grads, grad_names):
self.assertTrue(
np.allclose(
c_grad, g_grad, atol=1e-4),
"output name: " + name + " has diff")
import unittest
import numpy
from op_test_util import OpTestMeta
from gradient_checker import GradientChecker, create_op
from op_test import OpTest
class TestCrossEntropy(unittest.TestCase):
__metaclass__ = OpTestMeta
class TestCrossEntropy(OpTest):
def setUp(self):
self.type = "onehot_cross_entropy"
self.op_type = "onehot_cross_entropy"
batch_size = 30
class_num = 10
X = numpy.random.random((batch_size, class_num)).astype("float32")
label = 5 * numpy.ones(batch_size).astype("int32")
X = numpy.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32")
label = (class_num / 2) * numpy.ones(batch_size).astype("int32")
self.inputs = {'X': X, 'label': label}
Y = []
for i in range(0, batch_size):
Y.append(-numpy.log(X[i][label[i]]))
self.outputs = {'Y': numpy.array(Y).astype("float32")}
def test_check_output(self):
self.check_output()
class CrossEntropyGradOpTest(GradientChecker):
def test_check_grad(self):
op = create_op("onehot_cross_entropy")
batch_size = 30
class_num = 10
inputs = {
"X": numpy.random.uniform(
0.1, 1.0, [batch_size, class_num]).astype("float32"),
"label": (class_num / 2) * numpy.ones(batch_size).astype("int32")
}
self.check_grad(op, inputs, set("X"), "Y")
self.check_grad(["X"], "Y")
if __name__ == "__main__":
......
import unittest
import numpy as np
from op_test_util import OpTestMeta
from gradient_checker import GradientChecker, create_op
from op_test import OpTest
class TestSigmoidOp(unittest.TestCase):
__metaclass__ = OpTestMeta
class TestSigmoid(OpTest):
def setUp(self):
self.type = "sigmoid"
self.inputs = {'X': np.random.random((15, 31)).astype("float32")}
self.op_type = "sigmoid"
self.inputs = {
'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
}
self.outputs = {'Y': 1 / (1 + np.exp(-self.inputs['X']))}
def test_check_output(self):
self.check_output()
class TestSigmoidGradOp(GradientChecker):
def test_grad(self):
op = create_op("sigmoid")
inputs = {"X": np.random.uniform(0.1, 1, [11, 17]).astype("float32")}
# compare gpu and cpu results for backward op.
# this test will be skiped if only compiling CPU version.
self.compare_grad(op, inputs)
# check gradients
self.check_grad(op, inputs, set("X"), "Y", max_relative_error=0.007)
def test_check_grad(self):
self.check_grad(["X"], "Y", max_relative_error=0.007)
if __name__ == '__main__':
......
import unittest
import numpy as np
from op_test import OpTest
class TestSumOp(OpTest):
def setUp(self):
self.op_type = "sum"
x0 = np.random.random((3, 4)).astype('float32')
x1 = np.random.random((3, 4)).astype('float32')
x2 = np.random.random((3, 4)).astype('float32')
self.inputs = {"X": {"x0": x0, "x1": x1, "x2": x2}}
y = x0 + x1 + x2
self.outputs = {'Out': y}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["x0"], "Out")
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册