提交 6f0b5820 编写于 作者: M Megvii Engine Team

chore(imperative/amp): adapt dev

GitOrigin-RevId: 41eb0faadf34f4c31fb26b2de7850748666589c5
上级 ee984e86
...@@ -8,10 +8,9 @@ namespace cuda { ...@@ -8,10 +8,9 @@ namespace cuda {
#define COMMA , #define COMMA ,
#define cb(_dtype) \ #define cb(_dtype) \
INST_REDUCE( \ INST_REDUCE( \
device_reduce::CheckNonFiniteOp< \ device_reduce::CheckNonFiniteOp<_dtype COMMA dt_int32 COMMA dt_int32>, \
_dtype COMMA dt_float32 COMMA dt_int32 COMMA dt_int32>, \
false); false);
cb(dt_float32); cb(dt_float32);
......
...@@ -14,7 +14,7 @@ using device_reduce::CheckNonFiniteOp; ...@@ -14,7 +14,7 @@ using device_reduce::CheckNonFiniteOp;
template <typename T> template <typename T>
size_t CheckNonFiniteImpl::_get_workspace_in_bytes() { size_t CheckNonFiniteImpl::_get_workspace_in_bytes() {
// Call the _get_workspace_in_bytes to reduce the loop fetch workspace bytes // Call the _get_workspace_in_bytes to reduce the loop fetch workspace bytes
typedef CheckNonFiniteOp<T, dt_float32, dt_int32, dt_int32> Op; typedef CheckNonFiniteOp<T, dt_int32, dt_int32> Op;
megdnn_assert(m_size > 0); megdnn_assert(m_size > 0);
WorkspaceBundle bundle( WorkspaceBundle bundle(
nullptr, { nullptr, {
...@@ -59,7 +59,7 @@ void CheckNonFiniteImpl::_exec( ...@@ -59,7 +59,7 @@ void CheckNonFiniteImpl::_exec(
_megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst, _megdnn_in const TensorNDArray& srcs, _megdnn_tensor_out dst,
_megdnn_workspace workspace) { _megdnn_workspace workspace) {
check_exec(srcs, dst, workspace.size); check_exec(srcs, dst, workspace.size);
typedef CheckNonFiniteOp<T, dt_float32, dt_int32, dt_int32> Op; typedef CheckNonFiniteOp<T, dt_int32, dt_int32> Op;
auto stream = cuda_stream(this->handle()); auto stream = cuda_stream(this->handle());
SmallVector<size_t> workspace_sizes{ SmallVector<size_t> workspace_sizes{
sizeof(T*) * m_size, sizeof(T*) * m_size,
......
...@@ -247,4 +247,4 @@ def _override( ...@@ -247,4 +247,4 @@ def _override(
def _get_actual_op_param(function_param, config_param): def _get_actual_op_param(function_param, config_param):
return function_param if config_param is "default" else config_param return function_param if config_param == "default" else config_param
...@@ -97,7 +97,7 @@ class Optimizer(metaclass=ABCMeta): ...@@ -97,7 +97,7 @@ class Optimizer(metaclass=ABCMeta):
"optimizer can only optimize Parameters, but one of the params is " "optimizer can only optimize Parameters, but one of the params is "
+ str(type(param)) + str(type(param))
) )
param._reset(Tensor(param, no_cache=True)) param._reset(Tensor(param.numpy(), no_cache=True, format=param.format))
for name, default in self._defaults.items(): for name, default in self._defaults.items():
if default is required and name not in param_group: if default is required and name not in param_group:
......
...@@ -581,9 +581,9 @@ ValueRefList FormatTransformation::apply_transformation( ...@@ -581,9 +581,9 @@ ValueRefList FormatTransformation::apply_transformation(
(GenericFunction&)inputs[1].cast<FunctionValue>(); (GenericFunction&)inputs[1].cast<FunctionValue>();
// make param grads as FormattedTensor // make param grads as FormattedTensor
GenericFunction new_callback = GenericFunction new_callback =
[this, callback, format](Span<ValueRef> inputs_) -> ValueRefList { [&, callback, format](Span<ValueRef> inputs_) -> ValueRefList {
auto wrapped_inputs = SmallVector<ValueRef>{ auto wrapped_inputs = SmallVector<ValueRef>{
this->value_type().make(inputs_.item(), format)}; m_value_type.make(inputs_.item(), format)};
auto ret = callback(wrapped_inputs); auto ret = callback(wrapped_inputs);
return ret; return ret;
}; };
......
...@@ -67,7 +67,6 @@ template <typename T> ...@@ -67,7 +67,6 @@ template <typename T>
class Type : public IType { class Type : public IType {
protected: protected:
Type(std::string name) : IType(std::move(name)) {} Type(std::string name) : IType(std::move(name)) {}
Type(IType&& type) : IType(std::move(type)) {}
// TODO: each type owns an allocator // TODO: each type owns an allocator
public: public:
...@@ -105,7 +104,6 @@ template <typename T> ...@@ -105,7 +104,6 @@ template <typename T>
class ObjectType : public Type<T> { class ObjectType : public Type<T> {
public: public:
ObjectType(std::string name) : Type<T>(name) {} ObjectType(std::string name) : Type<T>(name) {}
ObjectType(IType&& type) : Type<T>(std::move(type)) {}
}; };
/** /**
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册