Commit 6f0b5820 authored by Megvii Engine Team

chore(imperative/amp): adapt dev

GitOrigin-RevId: 41eb0faadf34f4c31fb26b2de7850748666589c5
Parent ee984e86
@@ -10,8 +10,7 @@ namespace cuda {
 #define cb(_dtype)                                                              \
     INST_REDUCE(                                                                \
-            device_reduce::CheckNonFiniteOp<                                    \
-                    _dtype COMMA dt_float32 COMMA dt_int32 COMMA dt_int32>,     \
+            device_reduce::CheckNonFiniteOp<_dtype COMMA dt_int32 COMMA dt_int32>, \
             false);
 cb(dt_float32);
......
@@ -14,7 +14,7 @@ using device_reduce::CheckNonFiniteOp;
 template <typename T>
 size_t CheckNonFiniteImpl::_get_workspace_in_bytes() {
     // Call the _get_workspace_in_bytes to reduce the loop fetch workspace bytes
-    typedef CheckNonFiniteOp<T, dt_float32, dt_int32, dt_int32> Op;
+    typedef CheckNonFiniteOp<T, dt_int32, dt_int32> Op;
     megdnn_assert(m_size > 0);
     WorkspaceBundle bundle(
             nullptr, {
@@ -59,7 +59,7 @@ void CheckNonFiniteImpl::_exec(
         _megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst,
         _megdnn_workspace workspace) {
     check_exec(srcs, dst, workspace.size);
-    typedef CheckNonFiniteOp<T, dt_float32, dt_int32, dt_int32> Op;
+    typedef CheckNonFiniteOp<T, dt_int32, dt_int32> Op;
     auto stream = cuda_stream(this->handle());
     SmallVector<size_t> workspace_sizes{
             sizeof(T*) * m_size,
......
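Note on the CUDA hunks above: they drop the now-unused dt_float32 template argument of CheckNonFiniteOp, which reduces a batch of tensors to a single int32 flag marking whether any element is NaN or Inf. Ignoring the CUDA plumbing (and any scaling the real kernel may fuse in), the reduction computes roughly the following; the sketch below is a NumPy illustration, not MegDNN code:

```python
import numpy as np

def check_non_finite(tensors):
    """Return 1 if any element of any input is NaN or +/-Inf, else 0."""
    # np.isfinite is False for NaN, +Inf and -Inf, so one pass per tensor
    # mirrors the fused device-side reduction over the whole batch.
    return int(any(not np.isfinite(t).all() for t in tensors))

print(check_non_finite([np.ones(3), np.zeros(4)]))              # 0
print(check_non_finite([np.ones(3), np.array([1.0, np.inf])]))  # 1
```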
@@ -247,4 +247,4 @@ def _override(
 def _get_actual_op_param(function_param, config_param):
-    return function_param if config_param is "default" else config_param
+    return function_param if config_param == "default" else config_param
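The one-character fix above is worth spelling out: `is` tests object identity, not equality, and whether a runtime string happens to be the same object as the literal "default" is a CPython interning detail (CPython 3.8+ even emits a SyntaxWarning for `is` against a literal). A small self-contained demonstration:

```python
def _get_actual_op_param(function_param, config_param):
    # "==" compares string contents; "is" compares object identity.
    return function_param if config_param == "default" else config_param

# A string assembled at runtime is equal to, but not necessarily the same
# object as, the literal "default":
key = "".join(["de", "fault"])
lit = "default"
print(key == lit)  # True  -> the equality check takes the right branch
print(key is lit)  # typically False in CPython; identity is not guaranteed
print(_get_actual_op_param("float32", key))  # float32
```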
@@ -97,7 +97,7 @@ class Optimizer(metaclass=ABCMeta):
                     "optimizer can only optimize Parameters, but one of the params is "
                     + str(type(param))
                 )
-            param._reset(Tensor(param, no_cache=True))
+            param._reset(Tensor(param.numpy(), no_cache=True, format=param.format))
         for name, default in self._defaults.items():
             if default is required and name not in param_group:
......
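The Optimizer hunk rebuilds each parameter from its raw ndarray instead of from the Tensor itself, and explicitly forwards `format=param.format` so that a parameter tagged with a memory format such as "nhwc" keeps that tag across `_reset`; presumably the old `Tensor(param, no_cache=True)` path lost it. A minimal sketch of the intended behavior, assuming `Parameter` accepts a `format` keyword the way the `Tensor` constructor in this diff does:

```python
import numpy as np
from megengine import Parameter, Tensor

# Hypothetical parameter carrying a non-default memory format.
param = Parameter(np.ones((1, 4, 4, 2), dtype="float32"), format="nhwc")

# Rebuild from the raw ndarray; passing format explicitly preserves the tag.
fresh = Tensor(param.numpy(), no_cache=True, format=param.format)
assert fresh.format == param.format  # "nhwc" survives the reset
```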
@@ -581,9 +581,9 @@ ValueRefList FormatTransformation::apply_transformation(
                 (GenericFunction&)inputs[1].cast<FunctionValue>();
         // make param grads as FormattedTensor
         GenericFunction new_callback =
-                [this, callback, format](Span<ValueRef> inputs_) -> ValueRefList {
+                [&, callback, format](Span<ValueRef> inputs_) -> ValueRefList {
             auto wrapped_inputs = SmallVector<ValueRef>{
-                    this->value_type().make(inputs_.item(), format)};
+                    m_value_type.make(inputs_.item(), format)};
             auto ret = callback(wrapped_inputs);
             return ret;
         };
......
@@ -67,7 +67,6 @@ template <typename T>
 class Type : public IType {
 protected:
     Type(std::string name) : IType(std::move(name)) {}
-    Type(IType&& type) : IType(std::move(type)) {}
     // TODO: each type owns an allocator
 public:
@@ -105,7 +104,6 @@ template <typename T>
 class ObjectType : public Type<T> {
 public:
     ObjectType(std::string name) : Type<T>(name) {}
-    ObjectType(IType&& type) : Type<T>(std::move(type)) {}
 };

 /**
......