Python源码示例:sklearn.datasets.fetch_20newsgroups_vectorized()
示例1
def test_20news_vectorized():
    """Check the vectorized 20 newsgroups bunches for every subset.

    The raw dataset is first probed with ``download_if_missing=False``; an
    ``IOError`` from that probe skips the test instead of failing it.
    """
    try:
        datasets.fetch_20newsgroups(subset='all', download_if_missing=False)
    except IOError:
        raise SkipTest("Download 20 newsgroups to run this test")

    def _fetch_and_check(subset_name, expected_rows):
        # Fetch one subset and verify sparse format, shape and dtype.
        fetched = datasets.fetch_20newsgroups_vectorized(subset=subset_name)
        assert sp.isspmatrix_csr(fetched.data)
        assert_equal(fetched.data.shape, (expected_rows, 130107))
        assert_equal(fetched.target.shape[0], expected_rows)
        assert_equal(fetched.data.dtype, np.float64)
        return fetched

    # Training split.
    _fetch_and_check("train", 11314)

    # Test split; the resulting bunch is reused for the return_X_y check.
    test_bunch = _fetch_and_check("test", 7532)
    test_fetcher = partial(datasets.fetch_20newsgroups_vectorized,
                           subset='test')
    check_return_X_y(test_bunch, test_fetcher)

    # Both splits together.
    _fetch_and_check('all', 11314 + 7532)
示例2
def load_20newsgroup_vectorized(folder=SCIKIT_LEARN_DATA, one_hot=True, partitions_proportions=None,
                                shuffle=False, binary_problem=False, as_tensor=True, minus_value=-1.):
    """Load the vectorized 20 newsgroups train and test splits.

    Args:
        folder: forwarded as ``data_home`` to
            ``fetch_20newsgroups_vectorized``.
        one_hot: when true, both target arrays are passed through
            ``to_one_hot_enc``.
        partitions_proportions: when truthy, the two datasets are re-split
            with ``redivide_data`` using these proportions.
        shuffle: currently unused; accepted for interface compatibility.
        binary_problem: when true, labels below 10 are replaced by
            ``minus_value`` and all other labels by 1. The fetched target
            arrays are modified in place.
        as_tensor: when true, ``convert_to_tensor()`` is called on every
            resulting dataset.
        minus_value: label written for the "low" classes when
            ``binary_problem`` is set.

    Returns:
        The result of ``Datasets.from_list`` applied to the datasets.
    """
    data_train = sk_dt.fetch_20newsgroups_vectorized(data_home=folder, subset='train')
    data_test = sk_dt.fetch_20newsgroups_vectorized(data_home=folder, subset='test')

    X_train = data_train.data
    X_test = data_test.data
    y_train = data_train.target
    y_test = data_test.target

    if binary_problem:
        # y_train / y_test are the very arrays held by the bunches, so the
        # relabelling happens in place. Compute each mask once, before any
        # write: a mask evaluated after the first assignment would see the
        # already-rewritten labels and misclassify them when minus_value >= 10.
        for labels in (y_train, y_test):
            low_classes = labels < 10
            labels[low_classes] = minus_value
            labels[~low_classes] = 1.

    if one_hot:
        y_train = to_one_hot_enc(y_train)
        y_test = to_one_hot_enc(y_test)

    # NOTE: both datasets carry the target names of the training bunch.
    d_train = Dataset(data=X_train,
                      target=y_train, info={'target names': data_train.target_names})
    d_test = Dataset(data=X_test,
                     target=y_test, info={'target names': data_train.target_names})
    res = [d_train, d_test]

    if partitions_proportions:
        res = redivide_data([d_train, d_test], partition_proportions=partitions_proportions, shuffle=False)

    if as_tensor:
        # Plain loop: the call is made for its side effect only.
        for dat in res:
            dat.convert_to_tensor()

    return Datasets.from_list(res)
示例3
def test_20news_vectorized():
    """Verify format, shape and dtype of each vectorized 20 newsgroups subset.

    The test is skipped when probing the raw dataset with
    ``download_if_missing=False`` raises ``IOError``.
    """
    try:
        datasets.fetch_20newsgroups(subset='all', download_if_missing=False)
    except IOError:
        raise SkipTest("Download 20 newsgroups to run this test")

    # (subset, expected number of rows), checked in this order.
    expectations = (
        ("train", 11314),
        ("test", 7532),
        ('all', 11314 + 7532),
    )
    for subset_name, row_count in expectations:
        bunch = datasets.fetch_20newsgroups_vectorized(subset=subset_name)
        assert_true(sp.isspmatrix_csr(bunch.data))
        assert_equal(bunch.data.shape, (row_count, 130107))
        assert_equal(bunch.target.shape[0], row_count)
        assert_equal(bunch.data.dtype, np.float64)
示例4
def get_mldata(dataset):
    """Fetch a dataset and pickle it into ``FLAGS.save_dir`` if not yet saved.

    Args:
        dataset: indexable pair. ``dataset[0]`` selects the data source (a
            name ending in ``csv``, one of the named datasets handled below,
            a name containing ``keras``, or anything else, which is passed
            to ``fetch_mldata``). ``dataset[1]`` is the base name of the
            output file, saved as ``<save_dir>/<dataset[1]>.pkl``.

    Raises:
        Exception: when the fallback ``fetch_mldata`` call fails.
        AssertionError: when data and target still disagree on the number
            of rows after transposing the data.

    Nothing is returned; when the pickle file already exists the function
    does no work.
    """
    save_dir = FLAGS.save_dir
    filename = os.path.join(save_dir, dataset[1]+'.pkl')
    if not gfile.Exists(save_dir):
        gfile.MkDir(save_dir)
    if not gfile.Exists(filename):
        source = dataset[0]
        if source.endswith('csv'):
            data = get_csv_data(source)
        elif source == 'breast_cancer':
            data = load_breast_cancer()
        elif source == 'iris':
            data = load_iris()
        elif source == 'newsgroup':
            # Removing header information to make sure that no newsgroup
            # identifying information is included in data.
            # `remove` must be a tuple: ('headers') is just the string
            # 'headers', hence the trailing comma.
            data = fetch_20newsgroups_vectorized(subset='all', remove=('headers',))
            tfidf = TfidfTransformer(norm='l2')
            X = tfidf.fit_transform(data.data)
            data.data = X
        elif source == 'rcv1':
            # Point the rcv1 module at these URLs before fetching.
            sklearn.datasets.rcv1.URL = (
                'http://www.ai.mit.edu/projects/jmlr/papers/'
                'volume5/lewis04a/a13-vector-files/lyrl2004_vectors')
            sklearn.datasets.rcv1.URL_topics = (
                'http://www.ai.mit.edu/projects/jmlr/papers/'
                'volume5/lewis04a/a08-topic-qrels/rcv1-v2.topics.qrels.gz')
            data = sklearn.datasets.fetch_rcv1(
                data_home='/tmp')
        elif source == 'wikipedia_attack':
            data = get_wikipedia_talk_data()
        elif source == 'cifar10':
            data = get_cifar10()
        elif 'keras' in source:
            data = get_keras_data(source)
        else:
            try:
                data = fetch_mldata(source)
            except Exception:
                # Narrower than a bare except so KeyboardInterrupt and
                # SystemExit propagate unchanged.
                raise Exception('ERROR: failed to fetch data from mldata.org')
        X = data.data
        y = data.target
        # Make the first axis of X line up with y, transposing if needed.
        if X.shape[0] != y.shape[0]:
            X = np.transpose(X)
        assert X.shape[0] == y.shape[0]
        data = {'data': X, 'target': y}
        # Context manager guarantees the file is flushed and closed.
        with gfile.GFile(filename, 'w') as out_file:
            pickle.dump(data, out_file)