20 #ifndef TESSERACT_LSTM_TFNETWORK_H_ 21 #define TESSERACT_LSTM_TFNETWORK_H_ 23 #ifdef INCLUDE_TENSORFLOW 31 #include "tensorflow/core/framework/graph.pb.h" 32 #include "tensorflow/core/public/session.h" 36 class TFNetwork :
public Network {
38 explicit TFNetwork(
const STRING& name);
39 virtual ~TFNetwork() =
default;
42 StaticShape InputShape()
const override {
return input_shape_; }
45 StaticShape OutputShape(
const StaticShape& input_shape)
const override {
49 STRING spec()
const override {
return spec_.
c_str(); }
// Initializes the network from |proto_str|.
// NOTE(review): proto_str is presumably a serialized TFNetworkModel (compare
// model_proto_), and the meaning of the int result is defined by the
// implementation. Confirm both in the .cpp.
int InitFromProtoStr(const std::string &proto_str);
56 int num_classes()
const {
return output_shape_.depth(); }
67 void Forward(
bool debug,
const NetworkIO& input,
68 const TransposedArray* input_transpose,
69 NetworkScratch* scratch, NetworkIO* output)
override;
74 bool Backward(
bool debug,
const NetworkIO& fwd_deltas,
75 NetworkScratch* scratch,
76 NetworkIO* back_deltas)
override {
77 tprintf(
"Must override Network::DebugWeights for type %d\n", type_);
80 void DebugWeights()
override {
81 tprintf(
"Must override Network::DebugWeights for type %d\n", type_);
// Shape handed back by InputShape().
89 StaticShape input_shape_;
// Shape whose depth() is reported by num_classes().
91 StaticShape output_shape_;
// Owning handle to the TensorFlow session; destroyed together with the
// network. NOTE(review): presumably created during InitFromProtoStr and used
// by Forward. Confirm in the .cpp.
93 std::unique_ptr<tensorflow::Session> session_;
// NOTE(review): presumably the parsed model proto behind InitFromProtoStr.
// Confirm.
95 TFNetworkModel model_proto_;
100 #endif // ifdef INCLUDE_TENSORFLOW 102 #endif // TESSERACT_LSTM_TFNETWORK_H_
bool Serialize(FILE *fp, const char *data, size_t n)
DLLSYM void tprintf(const char *format,...)
bool DeSerialize(FILE *fp, char *data, size_t n)
const char * c_str() const