simplify code a bit

This commit is contained in:
Carl Pearson
2016-11-16 17:12:17 -06:00
parent 3968ac642a
commit 4c54bfadd7

View File

@@ -55,7 +55,7 @@ def d_sigmoid(vec):
def L(x, y):
return (x - y) * (x - y)
return 0.5 * (x - y) * (x - y)
class Model(object):
@@ -86,19 +86,14 @@ class Model(object):
return self.h(self.z1(x))
def f(self, x):
return self.w2.dot(self.a(x)) + self.b2
return np.dot(self.w2, self.a(x)) + self.b2
def dLdf(self, x, y):
return -2.0 * (y - self.f(x))
def dfdb2(self):
return np.array([1.0])
return self.f(x) - y
def dLdb2(self, x, y):
return self.dLdf(x, y) * self.dfdb2()
return self.dLdf(x, y)
def dfdw2(self, x): # how each entry of f changes wrt each entry of w2
return self.a(x)
def dfda(self): # how f changes with ith element of a
return self.w2
@@ -111,22 +106,16 @@ class Model(object):
"""Compute dL/dz1 for an input x and expected output y"""
return self.dLdf(x, y) * np.dot(self.dfda(), self.dadz1(x))
def dz1dw1(self, x):
return x * self.w1
def dLdw1(self, x, y):
"""Compute dL/dw1 for an input x and expected output y"""
return self.dLdf(x, y) * np.sum(self.dfda() * self.dadz1(x) * self.dz1dw1(x))
return self.dLdf(x, y) * np.dot(self.dfda(), self.dadz1(x) * x)
def dLdw2(self, x, y):
"""Compute dL/dw2 for an input x and expected output y"""
return self.dLdf(x, y) * self.dfdw2(x)
def dz1db1(self):
return np.ones(self.b1.shape)
return self.dLdf(x, y) * self.a(x) #df/dw2
def dLdb1(self, x, y):
return self.dLdf(x, y) * np.sum(self.dfda() * self.dadz1(x) * self.dz1db1())
return self.dLdf(x, y) * np.dot(self.dfda(), self.dadz1(x))
def backward(self, training_samples, ETA):
for sample in training_samples:
@@ -195,12 +184,12 @@ def evaluate(model, samples):
TRAIN_DATA, TEST_DATA = dataset_get_sin()
# TRAIN_DATA, TEST_DATA = dataset_get_linear()
MODEL = Model(10, sigmoid, d_sigmoid, DATA_TYPE)
MODEL = Model(60, sigmoid, d_sigmoid, DATA_TYPE)
# MODEL = Model(10, relu, d_relu, DATA_TYPE)
# Train the model with some training data
TRAINING_ITERS = 5000
LEARNING_RATE = 0.0005
LEARNING_RATE = 0.005
TRAINING_SUBSET_SIZE = len(TRAIN_DATA)
print TRAINING_SUBSET_SIZE