mean() 得到了一个意外的关键字参数“dtype”!

Posted

技术标签:

【中文标题】mean() 得到了一个意外的关键字参数“dtype”!【英文标题】:mean() got an unexpected keyword argument 'dtype'! 【发布时间】:2017-07-12 10:13:48 【问题描述】:

我正在尝试使用 Intel Bigdl 实现图像分类。它使用 mnist 数据集进行分类。因为,我不想使用 mnist 数据集,所以我编写了以下替代方法:

Image Utils.py

from StringIO import StringIO
from PIL import Image
import numpy as np
from bigdl.util import common
from bigdl.dataset import mnist
from pyspark.mllib.stat import Statistics

def label_img(img):
    word_label = img.split('.')[-2].split('/')[-1]
    print word_label
    # conversion to one-hot array [cat,dog]
    #                            [much cat, no dog]
    if "jobs" in word_label: return [1,0]
    #                             [no cat, very doggo]
    elif "zuckerberg" in word_label: return [0,1]

    # target is start from 0,

def get_data(sc,path):
    img_dir = path
    train = sc.binaryFiles(img_dir + "/train")
    test = sc.binaryFiles(img_dir+"/test")
    image_to_array = lambda rawdata: np.asarray(Image.open(StringIO(rawdata)))

    train_data = train.map(lambda x : (image_to_array(x[1]),np.array(label_img(x[0]))))
    test_data = test.map(lambda x : (image_to_array(x[1]),np.array(label_img(x[0]))))

    train_images = train_data.map(lambda x : x[0])
    test_images = test_data.map((lambda x : x[0]))
    train_labels = train_data.map(lambda x : x[1])
    test_labels = test_data.map(lambda x : x[1])

    training_mean = np.mean(train_images)
    training_std = np.std(train_images)
    rdd_train_images = sc.parallelize(train_images)
    rdd_train_labels = sc.parallelize(train_labels)
    rdd_test_images = sc.parallelize(test_images)
    rdd_test_labels = sc.parallelize(test_labels)

    rdd_train_sample = rdd_train_images.zip(rdd_train_labels).map(lambda (features, label):
                                        common.Sample.from_ndarray(
                                        (features - training_mean) / training_std,
                                        label + 1))
    rdd_test_sample = rdd_test_images.zip(rdd_test_labels).map(lambda (features, label):
                                        common.Sample.from_ndarray(
                                        (features - training_mean) / training_std,
                                        label + 1))

    return (rdd_train_sample, rdd_test_sample)

现在,当我尝试使用如下真实图像获取数据时:

Classification.py

import pandas
import datetime as dt

from bigdl.nn.layer import *
from bigdl.nn.criterion import *
from bigdl.optim.optimizer import *
from bigdl.util.common import *
from bigdl.dataset.transformer import *
from bigdl.dataset import mnist
from imageUtils import get_data

from StringIO import StringIO
from PIL import Image
import numpy as np

init_engine()

path = "/home/fusemachine/Hyper/person"
(train_data, test_data) = get_data(sc,path)
print train_data.count()
print test_data.count()

我收到以下错误

TypeError Traceback(最近一次调用>last) 在 ()

2 # 获取MNIST并将其存储到Sample的RDD中,请相应编辑“mnist_path”。

3 path = "/home/fusemachine/Hyper/person"

----> 4 (train_data, test_data) = get_data(sc,path)

5 打印 train_data.count()

6 打印 test_data.count()

/home/fusemachine/Downloads/dist-spark-2.1.0-scala-2.11.8-linux64-0.1.1-dist/imageUtils.py in get_data(sc, path)

31 test_labels = test_data.map(lambda x : x[1])

---> 33 training_mean = np.mean(train_images)

34 training_std = np.std(train_images)

35 rdd_train_images = sc.parallelize(train_images)

/opt/anaconda3/lib/python2.7/site-packages/numpy/core/fromnumeric.pyc in mean(a, axis, dtype, out, keepdims)

2884 次通过

2885 其他:

-> 2886 返回均值(轴=轴,dtype=dtype,out=out,**kwargs)

2887

2888 返回_methods._mean(a,axis=axis,dtype=dtype,

TypeError: mean() 得到了一个意外的关键字参数 'dtype'

我想不出解决办法。还有其他 mnist 数据集的替代方案。这样我们就可以直接处理真实的 Image 了? 谢谢

【问题讨论】:

【参考方案1】:

train_images 是一个 rdd,你不能在一个 rdd 上应用 numpy mean。一种方法是进行 collect() 并在此之上应用 numpy 均值,

 train_images = train_data.map(lambda x : x[0]).collect()
 training_mean = np.mean(train_images)

或rdd.mean()

  training_mean = train_images.mean()

【讨论】:

你能详细说明什么是“a rdd”吗?这看起来是正确的解释,但让我印象深刻的是“rdd”代码的改进领域——它应该接受dtype 参数以与numpy 兼容,即使它在传递不受支持的值时抛出NotImplementedError rdd 是 spark 的基础数据结构。更多详情请参考dzone.com/articles/what-is-rdd-in-spark-and-why-do-we-need-it

以上是关于mean() 得到了一个意外的关键字参数“dtype”!的主要内容,如果未能解决你的问题,请参考以下文章

Scikit 管道参数 - fit() 得到了一个意外的关键字参数“gamma”

fit() 得到了一个意外的关键字参数“标准”

SQLAlchemy:execute() 得到了一个意外的关键字参数

get() 得到了一个意外的关键字参数“pk”:django

TypeError: line() 得到了一个意外的关键字参数“标记”

Django - error_403() 得到了一个意外的关键字参数“异常”