about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- python/demo_feast_wrapper.py 37
-rw-r--r-- python/import_data.py 18
2 files changed, 29 insertions, 26 deletions
diff --git a/python/demo_feast_wrapper.py b/python/demo_feast_wrapper.py
index 6642ca8..4c965e3 100644
--- a/python/demo_feast_wrapper.py
+++ b/python/demo_feast_wrapper.py
@@ -1,36 +1,21 @@
#!/usr/bin/env python
import feast
import numpy as np
+import import_data
-##################################################################
-##################################################################
-##################################################################
-def read_digits(fname='digit.txt'):
- '''
- read_digits(fname='digit.txt')
-
- read a data file that contains the features and class labels.
- each row of the file is a feature vector with the class
- label appended.
- '''
- import csv
-
- fw = csv.reader(open(fname,'rb'), delimiter='\t')
- data = []
- for line in fw:
- data.append( [float(x) for x in line] )
- data = np.array(data)
- labels = data[:,len(data.transpose())-1]
- data = data[:,:len(data.transpose())-1]
- return data, labels
-##################################################################
-##################################################################
-##################################################################
+
+print '---> Loading digit data'
+
+data_source = 'uniform'
+
+
+if data_source == 'uniform':
+ data, labels = import_data.uniform_data()
+elif data_source == 'digits':
+ data, labels = import_data.read_digits('digit.txt')
-print '---> Loading digit data'
-data, labels = read_digits('digit.txt')
n_observations = len(data) # number of samples in the data set
n_features = len(data.transpose()) # number of features in the data set
n_select = 15 # how many features to select
diff --git a/python/import_data.py b/python/import_data.py
index c97ce7e..6d4bd9e 100644
--- a/python/import_data.py
+++ b/python/import_data.py
@@ -14,6 +14,7 @@ def read_digits(fname='digit.txt'):
label appended.
'''
import csv
+ import numpy as np
fw = csv.reader(open(fname,'rb'), delimiter='\t')
data = []
@@ -34,6 +35,23 @@ def read_digits(fname='digit.txt'):
##################################################################
def uniform_data(n_observations = 1000, n_features = 50, n_relevant = 5):
import numpy as np
+ xmax = 10
+ xmin = 0
+ data = np.random.randint(xmax + 1, size = (n_features, n_observations))
+ labels = np.zeros(n_observations)
+ delta = n_relevant * (xmax - xmin) / 2.0
+
+ for m in range(n_observations):
+ zz = 0.0
+ for k in range(n_relevant):
+ zz += data[k, m]
+ if zz > delta:
+ labels[m] = 1
+ else:
+ labels[m] = 2
+ data = data.transpose()
+
+ return data, labels
##################################################################
##################################################################