diff --git a/network.py b/network.py index bb2a49d..3c1aade 100644 --- a/network.py +++ b/network.py @@ -94,8 +94,8 @@ class Model(object): def dLdb2(self, x, y): return self.dLdf(x, y) * self.dfdb2() - def dfdw2(self, x): - return np.sum(self.a(x)) + def dfdw2(self, x): # how each entry of f changes wrt each entry of w2 + return self.a(x) def dfda(self): # how f changes with ith element of a return self.w2 @@ -106,10 +106,10 @@ class Model(object): def dLdz1(self, x, y): """Compute dL/dz1 for an input x and expected output y""" - return self.dLdf(x, y) * np.sum(self.dfda() * self.dadz1(x)) + return self.dLdf(x, y) * np.dot(self.dfda(), self.dadz1(x)) def dz1dw1(self, x): - return x + return x * self.w1 def dLdw1(self, x, y): """Compute dL/dw1 for an input x and expected output y"""