Fix preprocessing.text.Tokenizer
This commit is contained in:
parent
021d11e6cc
commit
b365fe450d
@ -59,7 +59,7 @@ class Tokenizer(object):
|
|||||||
self.word_index = dict(zip(sorted_voc, range(len(sorted_voc))))
|
self.word_index = dict(zip(sorted_voc, range(len(sorted_voc))))
|
||||||
|
|
||||||
self.index_docs = {}
|
self.index_docs = {}
|
||||||
for w, c in self.word_docs:
|
for w, c in self.word_docs.items():
|
||||||
self.index_docs[self.word_index[w]] = c
|
self.index_docs[self.word_index[w]] = c
|
||||||
|
|
||||||
|
|
||||||
@ -102,9 +102,9 @@ class Tokenizer(object):
|
|||||||
|
|
||||||
def texts_to_matrix(self, texts, mode="binary"):
|
def texts_to_matrix(self, texts, mode="binary"):
|
||||||
'''
|
'''
|
||||||
modes: binary, count, tfidf
|
modes: binary, count, tfidf, freq
|
||||||
'''
|
'''
|
||||||
sequences = self.to_sequences(texts)
|
sequences = self.texts_to_sequences(texts)
|
||||||
return self.sequences_to_matrix(sequences, mode=mode)
|
return self.sequences_to_matrix(sequences, mode=mode)
|
||||||
|
|
||||||
def sequences_to_matrix(self, sequences, mode="binary"):
|
def sequences_to_matrix(self, sequences, mode="binary"):
|
||||||
@ -112,17 +112,23 @@ class Tokenizer(object):
|
|||||||
modes: binary, count, tfidf, freq
|
modes: binary, count, tfidf, freq
|
||||||
'''
|
'''
|
||||||
if not self.nb_words:
|
if not self.nb_words:
|
||||||
raise Exception("Specify a dimension (nb_words argument")
|
if self.word_index:
|
||||||
|
nb_words = len(self.word_index)
|
||||||
|
else:
|
||||||
|
raise Exception("Specify a dimension (nb_words argument), or fit on some text data first")
|
||||||
|
else:
|
||||||
|
nb_words = self.nb_words
|
||||||
|
|
||||||
if mode == "tfidf" and not self.document_count:
|
if mode == "tfidf" and not self.document_count:
|
||||||
raise Exception("Fit the Tokenizer on some data before using tfidf mode")
|
raise Exception("Fit the Tokenizer on some data before using tfidf mode")
|
||||||
|
|
||||||
X = np.zeros((len(sequences), self.nb_words))
|
X = np.zeros((len(sequences), nb_words))
|
||||||
for i, seq in enumerate(sequences):
|
for i, seq in enumerate(sequences):
|
||||||
if not seq:
|
if not seq:
|
||||||
pass
|
pass
|
||||||
counts = {}
|
counts = {}
|
||||||
for j in seq:
|
for j in seq:
|
||||||
if j >= self.nb_words:
|
if j >= nb_words:
|
||||||
pass
|
pass
|
||||||
if j not in counts:
|
if j not in counts:
|
||||||
counts[j] = 1.
|
counts[j] = 1.
|
||||||
|
Loading…
Reference in New Issue
Block a user