Get the nonzero entries of a row of a SparseTensor
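I want to get all nonzero values from one row of a SparseTensor. Here "m" is my sparse tensor object and row is the row whose nonzero values and indices I want. The function should return an array of [(index, value)] pairs. I would appreciate any help with this.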
def nonzeros(m, row):
    res = []
    indices = m.indices
    values = m.values
    userindices = tf.where(tf.equal(indices[:,0], tf.constant(0, dtype=tf.int64)))
    res = tf.map_fn(lambda index:(indices[index][1], values[index]), userindices)
    return res
Error message in the terminal:
TypeError: Input 'strides' of 'StridedSlice' Op has type int32 that does not match type int64 of argument 'begin'.
Edit:
The input to nonzeros: cm is a coo_matrix holding the values.
m = tf.SparseTensor(indices=np.array([row,col]).T,
                    values=cm.data,
                    dense_shape=[10, 10])
nonzeros(m, 1)
If the data is
[[ 0. 1. 0. 0. 0. 0. 0. 0. 0. 1.]
[ 0. 0. 0. 0. 1. 0. 0. 0. 0. 2.]
[ 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]]
the result should be
[index, value]
[4,1]
[9,2]
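The problem is that index inside the lambda is a tensor, and you cannot use it directly to index into, for example, indices. You can use tf.gather instead. Also, the code you posted never uses the row argument (it always compares against the constant 0).
Try this: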
import tensorflow as tf
import numpy as np

def nonzeros(m, row):
    indices = m.indices
    values = m.values
    # Positions of the stored entries whose row index equals `row`; shape [N, 1].
    userindices = tf.where(tf.equal(indices[:, 0], row))
    # Gather the matching column indices (shape [N]) and values (shape [N, 1]).
    found_idx = tf.gather(indices, userindices)[:, 0, 1]
    found_vals = tf.gather(values, userindices)[:, 0:1]
    # TensorFlow 1.x argument order: tf.concat(values, axis).
    res = tf.concat([tf.expand_dims(tf.cast(found_idx, tf.float64), -1), found_vals], axis=1)
    return res

data = np.array([[0., 1., 0., 0., 0., 0., 0., 0., 0., 1.],
                 [0., 0., 0., 0., 1., 0., 0., 0., 0., 2.]])
m = tf.SparseTensor(indices=np.array([[0, 1], [0, 9], [1, 4], [1, 9]], dtype=np.int64),
                    values=np.array([1.0, 1.0, 1.0, 2.0]),
                    dense_shape=[2, 10])

with tf.Session() as sess:
    result = nonzeros(m, 1)
    print(sess.run(result))
This prints:
[[ 4. 1.]
[ 9. 2.]]
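To check the output against something simpler, here is a minimal plain-Python sketch of the same filter-then-gather logic, applied to the raw (row, col) index pairs and values that a SparseTensor stores. The function name nonzeros_coo is hypothetical and is not part of TensorFlow; the sample data is the same as in the answer above.

```python
def nonzeros_coo(indices, values, row):
    """Return [col, value] pairs for the stored entries in `row`.

    `indices` is a list of (row, col) pairs and `values` the matching list
    of entries, the same layout as SparseTensor.indices and .values.
    Hypothetical helper for illustration only.
    """
    # Keep only the entries whose row index matches, then emit column and value.
    return [[col, val] for (r, col), val in zip(indices, values) if r == row]


indices = [(0, 1), (0, 9), (1, 4), (1, 9)]
values = [1.0, 1.0, 1.0, 2.0]
print(nonzeros_coo(indices, values, 1))  # [[4, 1.0], [9, 2.0]]
```

Unlike the TensorFlow version, this keeps the column index as an integer, because a Python list does not need a single dtype for both columns.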