import numpy as np
import torch
# Demo: permuting rows by a sort order and then applying the inverse
# permutation gives back the original array.
a = np.random.random((4, 6))
lens = np.array([1, 3, 4, 2])
# Row order that arranges `lens` from largest to smallest.
idx_sorted = (-lens).argsort()
# argsort of a permutation is its inverse: it maps each sorted row back to
# the position it had before sorting.
idx_unsorted = idx_sorted.argsort()
print(a.take(idx_sorted, axis=0).take(idx_unsorted, axis=0) == a)
### PyTorch version
def __rnn_pack_scan(self, sentence_embs, input_lengths, rnn, h0=None):
    '''
    Run ``rnn`` over a padded batch using a packed sequence.

    The batch is sorted by decreasing length (as ``pack_padded_sequence``
    expects by default), packed, fed through ``rnn``, padded again, and
    finally put back into the caller's original batch order.

    :param sentence_embs: batch x seq_len x emb_size (batch-first padded input)
    :param input_lengths: 1-D tensor of size batch with the true length of
        each sequence; every length must be >= 1 because
        ``pack_padded_sequence`` rejects empty sequences
    :param rnn: recurrent module called as ``rnn(packed, h0)``; it must
        consume batch-first input
    :param h0: optional initial state in the caller's batch order, either a
        tensor of shape (num_layers * num_directions, batch, hidden_size) or
        a tuple of such tensors (e.g. ``(h, c)`` for an LSTM)
    :return: ``(outputs, hidden)`` where ``outputs`` is
        batch x longest_length_in_batch x features and ``hidden`` has the
        same structure the RNN returned (tensor or tuple), both in the
        caller's original batch order
    '''
    def _reorder_state(state, order):
        # Recurrent states keep the batch on dim 1; LSTM-style states are
        # tuples of tensors and each member needs the same permutation.
        if state is None:
            return None
        if isinstance(state, (tuple, list)):
            return tuple(torch.index_select(s, 1, order) for s in state)
        return torch.index_select(state, 1, order)

    # Sort the batch by decreasing length.
    len_sorted, indices = torch.sort(input_lengths, 0, descending=True)
    # index_select needs the index on the same device as the data.
    indices = indices.to(sentence_embs.device)
    sentence_embs_sorted = torch.index_select(sentence_embs, 0, indices)
    # The initial state must follow the same permutation as the inputs,
    # otherwise each sequence starts from another sequence's state.
    h0_sorted = _reorder_state(h0, indices)

    # Run the RNN on the packed batch.
    packed = torch.nn.utils.rnn.pack_padded_sequence(
        sentence_embs_sorted, len_sorted.tolist(), batch_first=True)
    outputs, hidden = rnn(packed, h0_sorted)
    # outputs: batch x longest_length x features (padding positions are zero)
    outputs, output_lengths = torch.nn.utils.rnn.pad_packed_sequence(
        outputs, batch_first=True)

    # Sorting a permutation yields its inverse, which restores the
    # original batch order.
    _, idx_unsorted = torch.sort(indices, 0)
    outputs_unsorted = torch.index_select(outputs, 0, idx_unsorted)
    hidden_unsorted = _reorder_state(hidden, idx_unsorted)
    return outputs_unsorted, hidden_unsorted