问题描述
使用torch.randint(0, 10, 10)
创建张量时出现错误
--------------------------------------------------------------------------- TypeError Traceback (most recent call last) <ipython-input-72-cdd6eb2d3416> in <module> ----> 1 torch.randint(0,10,10) TypeError: randint() received an invalid combination of arguments - got (int, int, int), but expected one of: * (int high, tuple of ints size, *, torch.Generator generator, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool requires_grad) * (int low, int high, tuple of ints size, *, torch.Generator generator, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool requires_grad)
原因分析:
函数的参数不对,torch.randint()
一般需要三个参数分别是low、high、size。
我传入的是下面参数,出现错误
torch.randint(0, 10, 10)
解决方案:
将size维度用元组传入
torch.randint(0, 10, (10,))