Add patience, save figure
This commit is contained in:
17
network.py
17
network.py
@@ -184,18 +184,18 @@ def evaluate(model, samples):
|
||||
TRAIN_DATA, TEST_DATA = dataset_get_sin()
|
||||
# TRAIN_DATA, TEST_DATA = dataset_get_linear()
|
||||
|
||||
MODEL = Model(60, sigmoid, d_sigmoid, DATA_TYPE)
|
||||
MODEL = Model(6, sigmoid, d_sigmoid, DATA_TYPE)
|
||||
# MODEL = Model(10, relu, d_relu, DATA_TYPE)
|
||||
|
||||
# Train the model with some training data
|
||||
TRAINING_ITERS = 5000
|
||||
LEARNING_RATE = 0.005
|
||||
TRAINING_SUBSET_SIZE = len(TRAIN_DATA)
|
||||
PATIENCE = 100
|
||||
|
||||
print TRAINING_SUBSET_SIZE
|
||||
|
||||
best_rate = np.inf
|
||||
rates = [["iter", "training_rate", "test_rate"]]
|
||||
for training_iter in range(TRAINING_ITERS):
|
||||
# Create a training sample
|
||||
training_subset_indices = npr.choice(
|
||||
@@ -219,16 +219,21 @@ for training_iter in range(TRAINING_ITERS):
|
||||
# Evaluate accuracy against training data
|
||||
training_rate = evaluate(MODEL, training_subset)
|
||||
test_rate = evaluate(MODEL, TEST_DATA)
|
||||
rates += [[training_iter, training_rate, test_rate]]
|
||||
|
||||
print training_iter, "positive rates:", training_rate, test_rate,
|
||||
print training_iter, "cost:", training_rate, test_rate,
|
||||
|
||||
# If it's the best one so far, store it
|
||||
if training_rate < best_rate:
|
||||
print "(new best)"
|
||||
best_rate = training_rate
|
||||
patience = PATIENCE
|
||||
else:
|
||||
print ""
|
||||
patience -= 1
|
||||
print patience
|
||||
|
||||
if patience <= 0:
|
||||
print PATIENCE, "iterations without improvement"
|
||||
break
|
||||
|
||||
TEST_OUTPUT = np.vectorize(MODEL.f)(TEST_DATA[:, 0])
|
||||
TRAIN_OUTPUT = np.vectorize(MODEL.f)(TRAIN_DATA[:, 0])
|
||||
@@ -240,5 +245,5 @@ scatter_train_out, = plt.plot(
|
||||
scatter_test_out, = plt.plot(
|
||||
TEST_DATA[:, 0], TEST_OUTPUT, 'bo', label="Network output on test data")
|
||||
plt.legend(handles=[scatter_train, scatter_train_out, scatter_test_out])
|
||||
|
||||
plt.savefig("results.png", bbox_inches="tight")
|
||||
plt.show()
|
||||
|
Reference in New Issue
Block a user