.
This commit is contained in:
11
network2.py
11
network2.py
@@ -98,7 +98,7 @@ class Model(object):
|
||||
return 2.0 * (self.f(x) - y)
|
||||
|
||||
def dLdb3(self, x, y):
|
||||
return self.dLdf(x, y) * np.ones(self.b3.shape)
|
||||
return self.dLdf(x, y)
|
||||
|
||||
def dLdw3(self, x, y):
|
||||
return self.dLdf(x, y) * np.sum(self.a2(x))
|
||||
@@ -114,17 +114,18 @@ class Model(object):
|
||||
def dLdb2(self, x, y):
|
||||
return self.dLdf(x, y) * self.dfdb2(x)
|
||||
|
||||
def dz2dw2(self, x):
|
||||
return np.sum(self.a2(x))
|
||||
def dz2dw2(self, x): # how z2 changes with a row of w2
|
||||
return np.sum(self.a1(x))
|
||||
|
||||
def da2dw2(self, x):
|
||||
return self.dh(self.z2(x)) * self.dz2dw2(x)
|
||||
|
||||
def dfdw2(self, x):
|
||||
return np.dot(self.w3, self.da2dw2(x))
|
||||
# print self.dfdz2(x).shape
|
||||
return np.dot(self.dfdz2(x), self.dz2dw2(x))
|
||||
|
||||
def dLdw2(self, x, y):
|
||||
return self.dLdf(x, y) * self.dfdw2(x)
|
||||
return self.dLdf(x, y) * np.sum(self.dfdw2(x))
|
||||
|
||||
# First layer updates
|
||||
|
||||
|
Reference in New Issue
Block a user