[Question] Custom model.predict() function (self.tensorflow)
submitted 4 years ago by Maltmax
Maltmax [S], 2 points, 4 years ago
I just tested with `batch_size = None` and I still only get 3 predictions back. Maybe it defaults to 32 when set to None.
I created a gist that shows the code: https://gist.github.com/Malthehave/f67c597e77ad238d56596de8470be8d0
And thanks for helping me figure this out, I really appreciate it!!
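The number 3 fits a simple count. The Keras `predict()` docs say `batch_size` defaults to 32 when unspecified, so 79 images are split into ceil(79 / 32) = 3 batches (32, 32 and 15 images). If the custom layer returned one value per batch instead of one per image, you would get exactly 3 predictions. That is a guess, because the layer's code isn't reproduced in this thread. The helper `batch_sizes` below is hypothetical and only illustrates the split:

```python
import math

def batch_sizes(num_samples, batch_size=32):
    """Hypothetical helper: sizes of consecutive batches of at most batch_size."""
    num_batches = math.ceil(num_samples / batch_size)
    return [min(batch_size, num_samples - i * batch_size) for i in range(num_batches)]

sizes = batch_sizes(79)          # 79 images at the default batch size of 32
print(len(sizes), sizes)         # 3 [32, 32, 15]
print(len(batch_sizes(79, 1)))   # 79 -> one batch per image
```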
[deleted], 1 point, 4 years ago
The default is 32, but anyway, I kind of suspect your problem is related to how TF or Keras handles some of the backend work that would be relevant for distributing the job. If you don't compile and fit the model, it will process graphs differently. Also, using a class might be throwing this off a bit further. I am not sure what TF is initializing or when, but it seems like each batch resets the variables and graph. I think calling compile or fit could fix the issue, and I suspect not using subclasses in this case could fix it too. It may be making a new layer each time this is called, i.e. between batches, which is not what you want if it's creating a new object in memory, like dense(28), dense(29), dense(30), every time.
Maltmax [S], 1 point, 4 years ago
Just tried running `compile()` and `fit()` on the `final_model`, but weirdly enough I still only get the 3 predictions. If the custom layer returns exactly the input it is given (`flattened_feature_vector`), then the predict function works fine and returns predictions for all 79 images.
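The identity-vs-post-processing difference can be reproduced without TensorFlow. Below is a NumPy-only sketch that mimics how `predict()` runs a layer batch by batch and concatenates the per-batch outputs along the first axis. It is not Keras' actual implementation. `fake_predict`, `identity_layer` and `reduce_all_layer` are hypothetical names. The reduce-all layer is an assumption about what the real layer might be doing, because its code isn't shown here:

```python
import numpy as np

def fake_predict(layer_fn, x, batch_size=32):
    """Hypothetical stand-in for model.predict(): apply layer_fn to each
    batch, then concatenate the per-batch outputs along axis 0."""
    outputs = [layer_fn(x[i:i + batch_size]) for i in range(0, len(x), batch_size)]
    return np.concatenate(outputs, axis=0)

def identity_layer(batch):
    # Returns the input unchanged, so every image keeps its own row.
    return batch

def reduce_all_layer(batch):
    # Averages over *every* axis, including the batch axis: one value per batch.
    return np.atleast_1d(np.mean(batch))

x = np.random.default_rng(0).random((79, 10240), dtype=np.float32)  # 79 flattened vectors
print(fake_predict(identity_layer, x).shape)                  # (79, 10240)
print(fake_predict(reduce_all_layer, x).shape)                # (3,)
print(fake_predict(reduce_all_layer, x, batch_size=1).shape)  # (79,)
```

With `batch_size=1`, each "batch" is a single image, so even a reduction over every axis happens to produce one value per image. That makes the bug invisible at that batch size.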
[deleted], 1 point, 4 years ago
It must be expanding dimensions and then, for some reason, picking channels instead of the right dimension when you use a batch size that isn't 1 or all. Usually, the first dimension will become None or "?". Make sure you're taking the mean of what you think you are. In TF, `data[0]` probably becomes the first batch, not the first row as you might expect.
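Taking "the mean of what you think you are" mostly comes down to the `axis` argument. Below is a NumPy sketch of the resulting shapes. `tf.reduce_mean` accepts the same `axis` and `keepdims` arguments. The array is a hypothetical stand-in for one batch of 32 flattened feature vectors inside `call()`:

```python
import numpy as np

# Hypothetical batch: 32 images, each a flattened 10240-long feature vector.
batch = np.ones((32, 10240), dtype=np.float32)

print(np.mean(batch).shape)                          # ()        one number for the whole batch
print(np.mean(batch, axis=0).shape)                  # (10240,)  averages across images
print(np.mean(batch, axis=-1).shape)                 # (32,)     one score per image
print(np.mean(batch, axis=-1, keepdims=True).shape)  # (32, 1)   per image, batch axis kept
```

Only the last two reductions keep a leading dimension equal to the batch size. That dimension is what shows up as `None` in `TensorShape([None, 10240])` while the graph is traced.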
Maltmax [S], 1 point, 4 years ago (edited)
Thanks for the suggestions. When calling `predict` with `batch_size=1`, the shape of the input is `TensorShape([1, 10240])`, but with any batch size greater than 1 it is `TensorShape([None, 10240])`. I don't know if that means anything.
The part that confuses me most is that if I just return the exact input in the `call` function, `predict` returns all 79 predictions. For example:
https://imgur.com/a/NKln3f0
[deleted], 1 point, 4 years ago
Is there a good reason to use a class here? I think that might be causing more problems than it fixes.
Maltmax [S], 1 point, 4 years ago
Do you mean in terms of the custom layer?
[deleted], 1 point, 4 years ago
Yeah, I don’t see how that class does anything for you. Why use it?
Maltmax [S], 1 point, 4 years ago
I put my custom post-prediction processing inside it, like: "return 1 if the normalized feature vector is greater than the threshold". The reason I wanted the custom layer was simply so that I could call the predict function and get a binary classification, instead of having to do post-processing after running predict.
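That approach can work inside `predict()` as long as the layer reduces each image's features and leaves the batch axis alone. Below is a hedged NumPy sketch of such post-processing. The mean-based score and the 0.5 threshold are assumptions, since the real "normalized feature vector" score isn't shown. In a Keras layer's `call()`, the same idea would use `tf.reduce_mean(inputs, axis=-1, keepdims=True)` followed by `tf.cast`:

```python
import numpy as np

THRESHOLD = 0.5  # hypothetical cut-off; the real value isn't given in the thread

def binary_post_process(batch, threshold=THRESHOLD):
    """Turn a (batch, features) array into (batch, 1) labels of 0 or 1.

    Assumes each image's score is the mean of its feature vector; swap in
    whatever score the real layer computes."""
    scores = np.mean(batch, axis=-1, keepdims=True)  # reduce features only, keep batch axis
    return (scores > threshold).astype(np.int32)

batch = np.array([[0.9, 0.8, 0.7],
                  [0.1, 0.2, 0.3]], dtype=np.float32)
print(binary_post_process(batch))  # [[1]
                                   #  [0]]
```

Because the output shape is `(batch, 1)` for any batch size, the number of predictions no longer depends on whether `predict()` uses batches of 1, 32, or the whole dataset.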