TensorFlow Source Code Analysis: The Relationship Between Sessions and Thread Pools

Posted by raintungli


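1. TensorFlow's SessionFactory

To create a new session, TensorFlow uses a multi-factory pattern: different factories serve different scenarios, and the SessionOptions passed in determine which factory is used.

1.1 Registering a factory

TensorFlow allows multiple session factories to be registered, so that different modules can register their own session factory.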

void SessionFactory::Register(const string& runtime_type,
                              SessionFactory* factory) {
  mutex_lock l(*get_session_factory_lock());
  if (!session_factories()->insert({runtime_type, factory}).second) {
    LOG(ERROR) << "Two session factories are being registered "
               << "under" << runtime_type;
  }
}

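By default TensorFlow provides two factories: one for DirectSession (single machine) and one for GrpcSession (cluster).

Which factory is used is determined by the target field of the SessionOptions passed in.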
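
The registration and lookup described above can be sketched with the standard library alone. This is a minimal illustration, not TensorFlow's code: FactoryRegistry, LocalFactory, GrpcFactory and the simplified SessionOptions are hypothetical stand-ins, and the acceptance rules (empty target for the local factory, a "grpc://" prefix for the cluster factory) are assumptions made for the example.

```cpp
#include <cassert>
#include <map>
#include <mutex>
#include <string>

// Hypothetical stand-in for tensorflow::SessionOptions.
struct SessionOptions {
  std::string target;  // assumed: empty for local, "grpc://host:port" for a cluster
};

// Hypothetical factory interface: each factory says whether it accepts the options.
class SessionFactory {
 public:
  virtual ~SessionFactory() = default;
  virtual bool AcceptsOptions(const SessionOptions& options) const = 0;
};

class LocalFactory : public SessionFactory {
 public:
  bool AcceptsOptions(const SessionOptions& options) const override {
    return options.target.empty();
  }
};

class GrpcFactory : public SessionFactory {
 public:
  bool AcceptsOptions(const SessionOptions& options) const override {
    return options.target.rfind("grpc://", 0) == 0;  // prefix match
  }
};

class FactoryRegistry {
 public:
  // Returns false if a factory is already registered under runtime_type.
  bool Register(const std::string& runtime_type, SessionFactory* factory) {
    std::lock_guard<std::mutex> lock(mu_);
    return factories_.insert({runtime_type, factory}).second;
  }

  // Returns the runtime type of the first factory that accepts the options,
  // or an empty string when none does.
  std::string Find(const SessionOptions& options) {
    std::lock_guard<std::mutex> lock(mu_);
    for (const auto& entry : factories_) {
      if (entry.second->AcceptsOptions(options)) return entry.first;
    }
    return "";
  }

 private:
  std::mutex mu_;
  std::map<std::string, SessionFactory*> factories_;
};
```

Registering the same runtime type twice fails in this sketch in the same way the insert(...).second check above detects a duplicate registration.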


2. TensorFlow's Session

2.1 Initializing a Session

In session.cc, a session is initialized through NewSession:

Session* NewSession(const SessionOptions& options) {
  SessionFactory* factory;
  const Status s = SessionFactory::GetFactory(options, &factory);
  if (!s.ok()) {
    LOG(ERROR) << s;
    return nullptr;
  }
  return factory->NewSession(options);
}

As the code shows, the session is created through the factory. For the single-machine case, this is the factory for DirectSession mentioned earlier:

  Session* NewSession(const SessionOptions& options) override {
    // Must do this before the CPU allocator is created.
    if (options.config.graph_options().build_cost_model() > 0) {
      EnableCPUAllocatorFullStats(true);
    }
    std::vector<Device*> devices;
    const Status s = DeviceFactory::AddDevices(
        options, "/job:localhost/replica:0/task:0", &devices);
    if (!s.ok()) {
      LOG(ERROR) << s;
      return nullptr;
    }

    DirectSession* session =
        new DirectSession(options, new DeviceMgr(devices), this);
    {
      mutex_lock l(sessions_lock_);
      sessions_.push_back(session);
    }
    return session;
  }

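2.2 Parallel computation

Every op that TensorFlow runs requires computation, and within a single session the ops need to run in parallel to finish quickly. On a cluster this means distributed computation; on a single machine it means multithreading, that is, the familiar thread pool.

The rest of this post focuses on single-machine parallelism, that is, the thread pools in DirectSession.

TensorFlow has three kinds of relationship between sessions and thread pools:

  1. A single session with multiple thread pools. When the session is initialized, the configuration of several thread pools is read from the SessionOptions and a vector of thread pools is built. If the global_name of a pool's options is empty, the pool is owned by the session, which must shut it down itself.
  2. A single session with a single thread pool. When the session is initialized with use_per_session_threads set in the SessionOptions, the single-pool configuration is read and a separate pool is created for that session. The session must shut it down itself.
  3. Multiple sessions sharing the same thread pool. When the session is initialized, a thread pool shared by all sessions is created. This pool is global and cannot be shut down.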
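
The three cases can be sketched as a single function that builds a list of (pool, owned) pairs. This is a simplified model rather than DirectSession's actual constructor: Pool, PoolOption, Config, GlobalPool, BuildPools and DestroyPools are hypothetical names, and a Pool here only records a thread count instead of starting threads.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a thread pool; only records its size.
struct Pool {
  int num_threads;
};

// Hypothetical mirror of one per-session pool option.
struct PoolOption {
  int num_threads;
  std::string global_name;  // empty: the pool is private to the session
};

// Hypothetical mirror of the relevant session configuration.
struct Config {
  std::vector<PoolOption> session_pools;  // case 1 when non-empty
  bool use_per_session_threads = false;   // case 2 when true
  int inter_op_threads = 4;
};

// Looks up or creates a process-wide pool. These pools are never destroyed.
inline Pool* GlobalPool(const std::string& name, int num_threads) {
  static auto* pools = new std::map<std::string, Pool*>();
  Pool*& slot = (*pools)[name];
  if (slot == nullptr) slot = new Pool{num_threads};
  return slot;
}

// Each entry is (pool, owned). Owned pools must be deleted by the session.
using PoolList = std::vector<std::pair<Pool*, bool>>;

inline PoolList BuildPools(const Config& config) {
  PoolList pools;
  if (!config.session_pools.empty()) {
    // Case 1: several pools, each either private or looked up by global name.
    for (const PoolOption& opt : config.session_pools) {
      if (opt.global_name.empty()) {
        pools.emplace_back(new Pool{opt.num_threads}, true);
      } else {
        pools.emplace_back(GlobalPool(opt.global_name, opt.num_threads), false);
      }
    }
  } else if (config.use_per_session_threads) {
    // Case 2: one private pool for this session.
    pools.emplace_back(new Pool{config.inter_op_threads}, true);
  } else {
    // Case 3: one pool shared by every session in the process.
    pools.emplace_back(GlobalPool("default", config.inter_op_threads), false);
  }
  return pools;
}

// Deletes only the pools the session owns.
inline void DestroyPools(PoolList* pools) {
  for (auto& entry : *pools) {
    if (entry.second) delete entry.first;
  }
  pools->clear();
}
```

Keeping the owned flag next to each pointer lets the session's cleanup skip shared pools without needing any global bookkeeping.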

The config.proto protocol buffer file defines the format of the configuration messages ConfigProto and ThreadPoolOptionProto:

message ThreadPoolOptionProto {
  // The number of threads in the pool.
  //
  // 0 means the system picks a value based on where this option proto is used
  // (see the declaration of the specific field for more info).
  int32 num_threads = 1;

  // The global name of the threadpool.
  //
  // If empty, then the threadpool is made and used according to the scope it's
  // in - e.g., for a session threadpool, it is used by that session only.
  //
  // If non-empty, then:
  // - a global threadpool associated with this name is looked
  //   up or created. This allows, for example, sharing one threadpool across
  //   many sessions (e.g., like the default behavior, if
  //   inter_op_parallelism_threads is not configured), but still partitioning
  //   into a large and small pool.
  // - if the threadpool for this global_name already exists, then it is an
  //   error if the existing pool was created using a different num_threads
  //   value as is specified on this call.
  // - threadpools created this way are never garbage collected.
  string global_name = 2;
};

message ConfigProto {
  // Map from device type name (e.g., "CPU" or "GPU" ) to maximum
  // number of devices of that type to use.  If a particular device
  // type is not found in the map, the system picks an appropriate
  // number.
  map<string, int32> device_count = 1;

  // The execution of an individual op (for some op types) can be
  // parallelized on a pool of intra_op_parallelism_threads.
  // 0 means the system picks an appropriate number.
  int32 intra_op_parallelism_threads = 2;

  // Nodes that perform blocking operations are enqueued on a pool of
  // inter_op_parallelism_threads available in each process.
  //
  // 0 means the system picks an appropriate number.
  //
  // Note that the first Session created in the process sets the
  // number of threads for all future sessions unless use_per_session_threads is
  // true or session_inter_op_thread_pool is configured.
  int32 inter_op_parallelism_threads = 5;

  // If true, use a new set of threads for this session rather than the global
  // pool of threads. Only supported by direct sessions.
  //
  // If false, use the global threads created by the first session, or the
  // per-session thread pools configured by session_inter_op_thread_pool.
  //
  // This option is deprecated. The same effect can be achieved by setting
  // session_inter_op_thread_pool to have one element, whose num_threads equals
  // inter_op_parallelism_threads.
  bool use_per_session_threads = 9;

  // This option is experimental - it may be replaced with a different mechanism
  // in the future.
  //
  // Configures session thread pools. If this is configured, then RunOptions for
  // a Run call can select the thread pool to use.
  //
  // The intended use is for when some session invocations need to run in a
  // background pool limited to a small number of threads:
  // - For example, a session may be configured to have one large pool (for
  //   regular compute) and one small pool (for periodic, low priority work);
  //   using the small pool is currently the mechanism for limiting the inter-op
  //   parallelism of the low priority work.  Note that it does not limit the
  //   parallelism of work spawned by a single op kernel implementation.
  // - Using this setting is normally not needed in training, but may help some
  //   serving use cases.
  // - It is also generally recommended to set the global_name field of this
  //   proto, to avoid creating multiple large pools. It is typically better to
  //   run the non-low-priority work, even across sessions, in a single large
  //   pool.
  repeated ThreadPoolOptionProto session_inter_op_thread_pool = 12;

  // Assignment of Nodes to Devices is recomputed every placement_period
  // steps until the system warms up (at which point the recomputation
  // typically slows down automatically).
  int32 placement_period = 3;

  // When any filters are present sessions will ignore all devices which do not
  // match the filters. Each filter can be partially specified, e.g. "/job:ps"
  // "/job:worker/replica:3", etc.
  repeated string device_filters = 4;

  // Options that apply to all GPUs.
  GPUOptions gpu_options = 6;

  // Whether soft placement is allowed. If allow_soft_placement is true,
  // an op will be placed on CPU if
  //   1. there's no GPU implementation for the OP
  // or
  //   2. no GPU devices are known or registered
  // or
  //   3. need to co-locate with reftype input(s) which are from CPU.
  bool allow_soft_placement = 7;

  // Whether device placements should be logged.
  bool log_device_placement = 8;

  // Options that apply to all graphs.
  GraphOptions graph_options = 10;

  // Global timeout for all blocking operations in this session.  If non-zero,
  // and not overridden on a per-operation basis, this value will be used as the
  // deadline for all blocking operations.
  int64 operation_timeout_in_ms = 11;

  // Options that apply when this session uses the distributed runtime.
  RPCOptions rpc_options = 13;

  // Optional list of all workers to use in this session.
  ClusterDef cluster_def = 14;

  // If true, any resources such as Variables used in the session will not be
  // shared with other sessions.
  bool isolate_session_state = 15;

  // Next: 16
};

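Creating multiple thread pools for a single session is mainly useful when a run needs to choose among different pools. Recall that a RunOptions can be passed when calling session.run. Again, let's look at the protocol definition directly: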

message RunOptions {
  // TODO(pbar) Turn this into a TraceOptions proto which allows
  // tracing to be controlled in a more orthogonal manner?
  enum TraceLevel {
    NO_TRACE = 0;
    SOFTWARE_TRACE = 1;
    HARDWARE_TRACE = 2;
    FULL_TRACE = 3;
  }
  TraceLevel trace_level = 1;

  // Time to wait for operation to complete in milliseconds.
  int64 timeout_in_ms = 2;

  // The thread pool to use, if session_inter_op_thread_pool is configured.
  int32 inter_op_thread_pool = 3;

  // Whether the partition graph(s) executed by the executor(s) should be
  // outputted via RunMetadata.
  bool output_partition_graphs = 5;

  // EXPERIMENTAL.  Options used to initialize DebuggerState, if enabled.
  DebugOptions debug_options = 6;

  // When enabled, causes tensor allocation information to be included in
  // the error message when the Run() call fails because the allocator ran
  // out of memory (OOM).
  //
  // Enabling this option can slow down the Run() call.
  bool report_tensor_allocations_upon_oom = 7;

  reserved 4;
}

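The relevant parameter is inter_op_thread_pool. In TensorFlow, both the communication protocols and the configuration are based on Google's Protocol Buffers, so the corresponding accessor functions and code are generated by compiling the protocol definitions. For example, in: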

  thread::ThreadPool* pool =
      thread_pools_[run_options.inter_op_thread_pool()].first;

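the inter_op_thread_pool function cannot be found in the source tree. During the build, TensorFlow generates the C++ code from config.proto automatically; the generated files are genfiles/tensorflow/core/protobuf/config.pb.h and config.pb.cc.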
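
As a rough illustration of how a per-run index picks a pool out of the session's list, here is a small standard-library sketch. SelectPool and Pool are hypothetical names, and returning nullptr for an out-of-range index is a choice made for this example, not a description of how TensorFlow reports the error. Because an unset int32 field in a protocol buffer reads as 0, a run that does not set the option selects the first pool.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical stand-in for thread::ThreadPool; only records its size.
struct Pool {
  int num_threads;
};

// Returns the pool selected by a RunOptions-style index, or nullptr when the
// index does not refer to a configured pool. Each entry is (pool, owned).
inline Pool* SelectPool(const std::vector<std::pair<Pool*, bool>>& thread_pools,
                        int inter_op_thread_pool) {
  if (inter_op_thread_pool < 0 ||
      static_cast<std::size_t>(inter_op_thread_pool) >= thread_pools.size()) {
    return nullptr;
  }
  return thread_pools[inter_op_thread_pool].first;
}
```

With a large pool at index 0 and a small pool at index 1, low-priority runs would pass index 1 and everything else would leave the option unset.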

2.2.1 Number of threads in a thread pool

int32 NumInterOpThreadsFromSessionOptions(const SessionOptions& options) {
  const int32 t = options.config.inter_op_parallelism_threads();
  if (t != 0) return t;
  // Default to using the number of cores available in the process.
  return port::NumSchedulableCPUs();
}

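The thread count comes from inter_op_parallelism_threads in the configuration; when multiple thread pools are configured, each pool's own num_threads is read instead. If nothing is configured, the default is the number of CPUs available on the system.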
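
The same fallback rule can be tried outside TensorFlow with a few lines of standard C++. In this sketch NumInterOpThreads is a hypothetical name and std::thread::hardware_concurrency() is an assumed substitute for port::NumSchedulableCPUs(); the two are not guaranteed to report the same number.

```cpp
#include <cassert>
#include <thread>

// A configured value of 0 means "let the system decide".
inline int NumInterOpThreads(int configured) {
  if (configured != 0) return configured;
  // hardware_concurrency() may return 0 when the count is unknown,
  // so fall back to a single thread in that case.
  const unsigned int cores = std::thread::hardware_concurrency();
  return cores == 0 ? 1 : static_cast<int>(cores);
}
```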

2.2.2 Thread pool implementation

TensorFlow's thread pool is implemented on top of Eigen's thread pool:

struct ThreadPool::Impl : Eigen::ThreadPoolTempl<EigenEnvironment> {
  Impl(Env* env, const ThreadOptions& thread_options, const string& name,
       int num_threads, bool low_latency_hint)
      : Eigen::ThreadPoolTempl<EigenEnvironment>(
            num_threads, low_latency_hint,
            EigenEnvironment(env, thread_options, name)) {}

  void ParallelFor(int64 total, int64 cost_per_unit,
                   std::function<void(int64, int64)> fn) {
    CHECK_GE(total, 0);
    CHECK_EQ(total, (int64)(Eigen::Index)total);
    Eigen::ThreadPoolDevice device(this, this->NumThreads());
    device.parallelFor(
        total, Eigen::TensorOpCost(0, 0, cost_per_unit),
        [&fn](Eigen::Index first, Eigen::Index last) { fn(first, last); });
  }
};
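
To make the contract of ParallelFor concrete, namely that fn(first, last) is called on disjoint sub-ranges covering [0, total), here is a deliberately naive sketch. It is not how Eigen schedules work: ParallelForSketch is a hypothetical name, it starts fresh std::thread objects on every call instead of reusing a pool, and it splits the range evenly rather than using a cost model such as cost_per_unit.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <thread>
#include <vector>

// Splits [0, total) into at most num_threads contiguous shards and runs
// fn(first, last) on each shard in its own thread, then waits for all of them.
inline void ParallelForSketch(
    std::int64_t total, int num_threads,
    const std::function<void(std::int64_t, std::int64_t)>& fn) {
  assert(total >= 0 && num_threads > 0);
  if (total == 0) return;
  const std::int64_t shards = std::min<std::int64_t>(total, num_threads);
  const std::int64_t block = (total + shards - 1) / shards;  // ceiling division
  std::vector<std::thread> workers;
  for (std::int64_t first = 0; first < total; first += block) {
    const std::int64_t last = std::min(first + block, total);
    workers.emplace_back([&fn, first, last] { fn(first, last); });
  }
  for (std::thread& worker : workers) worker.join();
}
```

Because the shards never overlap, a callback that writes only to indices in [first, last) needs no locking.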





